In [1]:
import os
import time
import numpy as np
import scipy as sp
import scipy.io as sio
import matplotlib.pyplot as plt
import matplotlib as mpl
import numpy.linalg as la
import glob
import nodepy.linear_multistep_method as lm
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
import tensorflow as tf
import tensorflow_datasets as tfds
#import haiku as hk
import jax
import flax.linen as nn
from flax.training import train_state  # Useful dataclass to keep train state
from typing import Sequence
import optax       
from jax import jacfwd, jacrev
In [2]:
import jax
jax.__version__
Out[2]:
'0.3.25'
In [3]:
import seaborn as sns
from matplotlib import rc

def set_style_sns():
    sns.set_context('paper')
    # single call: successive sns.set() calls reset previously set options
    sns.set(font = 'serif', font_scale = 1.3)

    sns.set_style('white', {
        'font.family': 'serif',
        'font.serif': ['Times', 'Palatino', 'serif'],
        'lines.markersize': 10
    })
    
    
plt.rcParams.update({'font.size':16})
set_style_sns()

plt.rcParams.update({'text.usetex': False})
In [4]:
import scipy.io as sio
data = sio.loadmat('DOF6_ROM_ml-20230324-124511cubic_conserve-1-dof0064')
In [5]:
Mhat = data['Mhat']
Khat = data['Khat']
Chat = np.zeros((6,6)) #data['Chat']
t = data['t']
dt = data['dt']

q_ = data['Qr']
dq_ = data['Qdr']
ddq_ = data['Qddr'] 

x = data['Q']
dx = data['Qd']
ddx = data['Qdd'] 

V = data['v']

q = np.dot(V.T, x)
dq = np.dot(V.T, dx)
ddq = np.dot(V.T, ddx)

Qext = np.zeros(q.shape)
Qrr = data['Qrr']

x_ = np.dot(V, q)
dx_ = np.dot(V, dq)
ddx_ = np.dot(V, ddq)

Ntt = data['Ntt'][0][0]

print(q.shape, dq.shape, ddq.shape, q_.shape)
(6, 16384) (6, 16384) (6, 16384) (6, 16384)
In [6]:
Ntt
Out[6]:
8192
In [7]:
print(Khat.shape, Ntt)
(6, 6) 8192
In [8]:
plt.figure(figsize = (10,4))

sample = 10

plt.plot(t[0], x.T[:, sample], 'b', alpha = 0.5)
plt.plot(t[0], Qrr.T[2:, sample], 'r', alpha = 0.2)

plt.ylabel('Displacement')
plt.xlabel('Time, s')

plt.show()
In [9]:
plt.figure(figsize = (10,4))

sample = np.array([33, 10])

plt.plot(t[0], ddx.T[:, sample])
plt.plot(t[0],ddx_.T[:, sample], 'k--')

plt.xlim([0,18])

plt.show()
In [10]:
plt.figure(figsize = (14,6))

#Ntt = q.shape[1]

plt.subplot(4,1,1)
plt.plot(t[0,:Ntt], q[:,:Ntt].T)

plt.subplot(4,1,2)
plt.plot(t[0,:Ntt], dq[:,:Ntt].T)

plt.subplot(4,1,3)
plt.plot(t[0,:Ntt], ddq[:,:Ntt].T)

plt.tight_layout()
plt.show()
In [11]:
Ntt
Out[11]:
8192
In [54]:
nDof = q.shape[0]

test_factor = 2

start = 0
subset = Ntt*test_factor #8200 #45000 #12000
ixend = Ntt*test_factor + 200 #8600 #6500 #45500 #12500
# Choose training data - using linear
QnormTrainData = np.concatenate([q.T[start:,:,np.newaxis], dq.T[start:,:,np.newaxis], ddq.T[start:,:,np.newaxis], Qext.T[start:,:,np.newaxis]], axis = 2).reshape((-1,nDof,4))
#xmax_for_norm = np.max(np.max(normTrainData, axis = 0), axis = 0)
#normTrainData = normTrainData/xmax
QnormTrainData = QnormTrainData[:subset, :, :]
q_train_dataset = tf.data.Dataset.from_tensor_slices(tf.cast(QnormTrainData,dtype=tf.float32))
#train_dataset = train_dataset.batch(256)

QnormTestData =  np.concatenate([q.T[start:,:,np.newaxis], dq.T[start:,:,np.newaxis], ddq.T[start:,:,np.newaxis], Qext.T[start:,:,np.newaxis]], axis = 2).reshape((-1,nDof,4))
QnormTestData = QnormTestData[subset:ixend, :, :]
q_test_dataset = tf.data.Dataset.from_tensor_slices(tf.cast(QnormTestData,dtype=tf.float32))
#test_dataset = test_dataset.batch(512)
In [55]:
# Choose training data - using linear
Fext = np.zeros(x.shape)
full_nDof = x.shape[0]
XnormTrainData = np.concatenate([x.T[start:,:,np.newaxis], dx.T[start:,:,np.newaxis], ddx.T[start:,:,np.newaxis], Fext.T[start:,:,np.newaxis]], axis = 2).reshape((-1,full_nDof,4))
#xmax_for_norm = np.max(np.max(normTrainData, axis = 0), axis = 0)
#normTrainData = normTrainData/xmax
XnormTrainData = XnormTrainData[:subset, :, :]
x_train_dataset = tf.data.Dataset.from_tensor_slices(tf.cast(XnormTrainData,dtype=tf.float32))
#train_dataset = train_dataset.batch(256)

XnormTestData =  np.concatenate([x.T[start:,:,np.newaxis], dx.T[start:,:,np.newaxis], ddx.T[start:,:,np.newaxis], Fext.T[start:,:,np.newaxis]], axis = 2).reshape((-1,full_nDof,4))
XnormTestData = XnormTestData[subset:ixend, :, :]
x_test_dataset = tf.data.Dataset.from_tensor_slices(tf.cast(XnormTestData,dtype=tf.float32))
#test_dataset = test_dataset.batch(512)
In [15]:
print(XnormTrainData.shape, XnormTestData.shape)
print(QnormTrainData.shape, QnormTestData.shape)
(8192, 64, 4) (200, 64, 4)
(8192, 6, 4) (200, 6, 4)
In [16]:
plt.figure(figsize = (10,6))

plt.subplot(2,1,1)
plt.plot(QnormTrainData[:,:,1])
plt.plot(np.diff(QnormTrainData[:,:,0], axis = 0)/dt, 'k--')

plt.subplot(2,1,2)
plt.plot(QnormTrainData[:,:,2])
plt.plot(np.diff(QnormTrainData[:,:,1], axis = 0)/dt, 'k--')

plt.show()
In [17]:
plt.figure(figsize = (10,6))

plt.subplot(2,1,1)
plt.plot(XnormTrainData[:,:,1])
plt.plot(np.diff(XnormTrainData[:,:,0], axis = 0)/dt, 'k--')
plt.xlim([1000,1200])

plt.subplot(2,1,2)
plt.plot(XnormTrainData[:,:,2])
plt.plot(np.diff(XnormTrainData[:,:,1], axis = 0)/dt, 'k--')
plt.xlim([1000,1200])

plt.show()
In [18]:
NLAYERS = 4
NNODES = 6 #10

class MLP(nn.Module):
  features: Sequence[int]
  
  @nn.compact
  def __call__(self, xx):
    for feat in self.features[:-1]:
      xx = nn.tanh(nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform())(xx))
      #xx = nn.Dropout(0.3)(xx, deterministic=DETERMINISTIC)
    xx = nn.Dense(self.features[-1], use_bias = True)(xx)

    return xx

class MLP_constant(nn.Module):
  features: Sequence[int]
  
  @nn.compact
  def __call__(self, xx):

    xx = nn.Dense(self.features[0], use_bias = False)(xx)

    return xx

class MLP_nl(nn.Module):
  features: Sequence[int]
  
  @nn.compact
  def __call__(self, xx):
    xx_poly = polynomialFeatures(jnp.reshape(xx, ((-1,))), degree = 3, interaction_only = False, include_bias = False).reshape((-1,)) #jnp.reshape(jnp.power(jnp.reshape(xx, ((-1,1))), jnp.array([1,2,3,4])), dim*4) #
    
    for feat in self.features[:-1]:
      xx = nn.swish(nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform(), use_bias = False)(xx))

    xx = jnp.concatenate([xx_poly, xx]) #skip connection   
    xx = nn.Dense(self.features[-1], use_bias = False)(xx)

    return xx

class MLP_params(nn.Module):
  features: Sequence[int]
  sizes_2 = [10,15]
  
  @nn.compact
  def __call__(self, xx, in_params):
    in_params = jnp.reshape(in_params, ((-1,)))
    xx = jnp.concatenate([xx, in_params])
    for feat in self.features[:-1]:
      xx = nn.tanh(nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform())(xx))
      #xx = nn.Dropout(0.3)(xx, deterministic=DETERMINISTIC)
    xx = nn.Dense(self.features[-1], use_bias = True)(xx)
    
    #for feat in self.sizes_2:
    #    in_params = nn.swish(nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform(), use_bias = True)(in_params))
    
    #in_params = nn.Dense(self.features[-1], use_bias = True)(in_params)
    
    return xx


class CNN(nn.Module):
  features = [2, 10, 10, 10, 1]

  @nn.compact
  def __call__(self, xx):
    for feat in self.features[:-1]:
      xx = nn.Conv(features=feat, kernel_size = (2,2), padding = 'same')(xx)
      xx = nn.swish(xx)
    xx = nn.Conv(features=self.features[-1], kernel_size = (2,2), padding = 'same')(xx)

    return xx

class CNN1D(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, xx):
    x = xx[0,:LATENT,:]
    dx = xx[0,LATENT:,:]
    
    if self.features[0] == 0:
        xx = x
        xx = jnp.reshape(xx, ((1, LATENT, 1)))
    elif self.features[0] == 1:
        xx = dx
        xx = jnp.reshape(xx, ((1, LATENT, 1)))

    xx_poly = jnp.reshape(jnp.power(jnp.reshape(xx, ((-1,1))), jnp.array([1,2])), LATENT*2)
    xx = nn.Conv(features=4, kernel_size = (1,2), padding = 'valid', use_bias = False)(xx)
    xx = polypool(xx)

    xx = jnp.concatenate([xx_poly, xx]) #skip connection 
    xx = nn.Dense(1, use_bias = True)(xx)

    return xx

class MLP_struct(nn.Module):
  features: Sequence[int]

  @nn.compact
  def __call__(self, xx):
    x = xx[0,:LATENT,:]
    dx = xx[0,LATENT:,:]


    term1 = x[0]**2
    term2 = x[1]**2
    term3 = (x[1] - x[0])**2
    
    term4 = dx[0]**2
    term5 = dx[1]**2
    term6 = (dx[1] - dx[0])**2

    
    
    xvec = jnp.array([term1,term2,term3,term4,term5,term6]).reshape((-1,))

    xvec = xvec[jnp.array(self.features)]

    
    xx = nn.Dense(1, use_bias = False)(xvec)

    return xx

class MLP_poly(nn.Module):
  features: Sequence[int]
  degree: Sequence[int]
  sizes = [6]

  @nn.compact
  def __call__(self, xx):
    x = xx[0,:LATENT,:]
    dx = xx[0,LATENT:,:]
    
    if self.features[0] == 0:
        xx = x
        xx = jnp.reshape(xx, ((LATENT,)))
        dim = LATENT
    elif self.features[0] == 1:
        xx = dx
        xx = jnp.reshape(xx, ((LATENT,)))
        dim = LATENT
    else:
        xx = jnp.reshape(xx, ((LATENT*2,)))
        dim = LATENT*2
        
    xx_poly = polynomialFeatures(jnp.reshape(xx, ((-1,))), degree = self.degree[0], interaction_only = False, include_bias = False).reshape((-1,)) #jnp.reshape(jnp.power(jnp.reshape(xx, ((-1,1))), jnp.array([1,2,3,4])), dim*4) #
    xx_sin = jnp.reshape(jnp.sin(xx), dim)
    xx_cos = jnp.reshape(jnp.cos(xx), dim)
    
    for feat in self.sizes[:-1]:
      xx = nn.swish(nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform(), use_bias = False)(xx))

    xx = nn.swish(nn.Dense(self.sizes[-1], use_bias = False, kernel_init=nn.initializers.glorot_uniform())(xx))
    xx = xx_poly #jnp.concatenate([xx_poly, xx]) #skip connection 
    xx = nn.Dense(1, use_bias = True)(xx)    
    
    return xx

class MLP_poly_fom(nn.Module):
  features: Sequence[int]
  degree: Sequence[int]
  sizes = [30,20,10,6]

  @nn.compact
  def __call__(self, xx):
    #x = xx[0,:LATENT,:]
    #dx = xx[0,LATENT:,:]
    
    xx = jnp.reshape(xx, ((-1,)))
    
    #xx_poly = polynomialFeatures(jnp.reshape(xx, ((-1,))), degree = self.degree[0], interaction_only = False, include_bias = False).reshape((-1,)) #jnp.reshape(jnp.power(jnp.reshape(xx, ((-1,1))), jnp.array([1,2,3,4])), dim*4) #
    #xx_sin = jnp.reshape(jnp.sin(xx), dim)
    #xx_cos = jnp.reshape(jnp.cos(xx), dim)
    
    for feat in self.sizes[:-1]:
      xx = nn.swish(nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform(), use_bias = False)(xx))

    xx = nn.swish(nn.Dense(self.sizes[-1], use_bias = False, kernel_init=nn.initializers.glorot_uniform())(xx))
    #xx = xx_poly #jnp.concatenate([xx_poly, xx]) #skip connection 
    xx = nn.Dense(1, use_bias = True)(xx)    
    
    return xx


class MLP_poly_latent(nn.Module):
  features: Sequence[int]
  nlayers = NLAYERS
  nodes = NNODES
  sizes = list(nlayers*[nodes])

  @nn.compact
  def __call__(self, xx):
    x = xx[0,:LATENT,:]
    dx = xx[0,LATENT:,:]

    
    if self.features[0] == 0:
        xx = x
        xx = jnp.reshape(xx, ((LATENT,)))
        dim = LATENT
    elif self.features[0] == 1:
        xx = dx
        xx = jnp.reshape(xx, ((LATENT,)))
        dim = LATENT
    else:
        xx = jnp.reshape(xx, ((LATENT*2,)))
        dim = LATENT*2

        
    xx_poly = jnp.reshape(jnp.power(jnp.reshape(xx, ((-1,1))), jnp.array([1,2,3,4])), dim*4) # polynomialFeatures(jnp.reshape(xx, ((-1,))), degree = 4, interaction_only = False, include_bias = False).reshape((-1,))
    xx_sin = jnp.reshape(jnp.sin(xx), dim)
    xx_cos = jnp.reshape(jnp.cos(xx), dim)

    for feat in self.sizes[:-1]:
      xx = nn.swish(nn.Dense(feat, kernel_init=nn.initializers.glorot_uniform(), use_bias = False)(xx))

    print(xx.shape, xx_sin.shape, xx_poly.shape)
    xx = nn.swish(nn.Dense(self.sizes[-1], use_bias = False, kernel_init=nn.initializers.glorot_uniform())(xx))
    xx = jnp.concatenate([xx_poly, xx]) #skip connection 
    xx = nn.Dense(LATENT, use_bias = True)(xx)  

    
    return xx


def hessian(f, argnums):
    return jacfwd(jacrev(f, argnums = argnums), argnums = argnums)
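# Forward-over-reverse (jacfwd of jacrev) is the standard efficient way to build
# dense Hessians in JAX: reverse mode for the gradient, forward mode across its rows.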

def polypool(x):
    #x = jnp.sum(x, axis = 1)
    #x = jnp.reshape(x, (5,))
    x = jnp.power(x, jnp.array([1,2,3,4]))
    x = jnp.reshape(x, ((-1,)))
    return x


def polynomialFeatures( X, degree = 2, interaction_only = False, include_bias = True ) :
    features = X
    prev_chunk = X
    indices = list( range( len( X ) ) )

    for d in range( 1, degree ) :
        # Create a new chunk of features for the degree d:
        new_chunk = []
        # Multiply each component with the products from the previous lower degree:
        for i, v in enumerate( X[:-d] if interaction_only else X ) :
            # Store the index where to start multiplying with the current component
            # at the next degree up:
            next_index = len( new_chunk )
            for coef in prev_chunk[indices[i+( 1 if interaction_only else 0 )]:] :
                new_chunk.append( v*coef )
            indices[i] = next_index
        # Extend the feature vector with the new chunk of features from the degree d:
        features = jnp.append( features, jnp.array(new_chunk))
        prev_chunk = new_chunk

    if include_bias :
        features = jnp.insert( features, 0, 1 )

    return features
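# Example (sketch): for X = [x1, x2] with degree = 2, interaction_only = False and
# include_bias = False, polynomialFeatures returns [x1, x2, x1*x1, x1*x2, x2*x2];
# each degree-d chunk multiplies every component with the tail of the previous chunk.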
In [19]:
LATENT = nDof
ORIGINAL_DIM = 64
MSTEPS = 2
METHOD = lm.Adams_Bashforth(MSTEPS)
ALPHA = np.float32(-METHOD.alpha[::-1])
BETA = np.float32(METHOD.beta[::-1])
QMAX = np.max(np.max(QnormTrainData[:,:,:], axis = 0), axis = 0)
XMAX = np.max(np.max(XnormTrainData[:,:,:], axis = 0), axis = 0)
DETERMINISTIC = False
print(XMAX.shape)

MHAT = Mhat
KHAT = Khat
CHAT = Chat
VMAT = jnp.array(V)

DT = dt

class CanonicalTransformer(nn.Module):
  def setup(self):
    self.E_layers = MLP([10,10,25,LATENT*LATENT])
    self.latent = LATENT
    self.original = ORIGINAL_DIM
    

  def __call__(self, x, dx, ddx):
    x = jnp.reshape(x, (1, self.latent))
    dx = jnp.reshape(dx, (1, self.latent))
    ddx = jnp.reshape(ddx, (1, self.latent))
    in_x = jnp.reshape(x, (self.latent)) #jnp.reshape(jnp.concatenate([x, dx, ddx], axis = 1), (self.original*3))
    
    T_dummy = jnp.reshape(self.E_layers(in_x), (self.latent,self.latent))
    T = jnp.eye(self.latent) #jnp.reshape(self.E_layers(in_x), (self.latent,self.latent))
    Q = jnp.dot(T, 
                jnp.reshape(in_x, self.latent))
    Q = jnp.reshape(Q, (self.latent,))
    
    return Q

  def get_T(self, x):
    x = jnp.reshape(x, (1, self.latent))
    #dx = jnp.reshape(dx, (1, self.original))
    #ddx = jnp.reshape(ddx, (1, self.original))
    in_x = jnp.reshape(x, (self.latent))
    
    T = jnp.eye(self.latent)
    return T   
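# NOTE: as configured, CanonicalTransformer is frozen to the identity (T = I);
# E_layers is still instantiated (T_dummy) so its parameters exist in the pytree.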
    

    
class InverseCanonicalTransformer(nn.Module):
  def setup(self):
    self.latent = LATENT
    self.original = ORIGINAL_DIM
    self.D_layers = MLP_nl([6,6,10,self.original])
    
  def __call__(self, x, dx, ddx):
    x = jnp.reshape(x, (1, self.latent, 1))
    dx = jnp.reshape(dx, (1, self.latent, 1))
    ddx = jnp.reshape(ddx, (1, self.latent, 1))
    in_x = jnp.reshape(x, (self.latent)) #jnp.reshape(jnp.concatenate([x, dx, ddx], axis = 1), (self.latent*3))
    
    
    X_nl= jnp.reshape(self.D_layers(in_x), (self.original,))
    T = VMAT
    X_lin = jnp.dot(T, 
                jnp.reshape(in_x, self.latent))
    
    #X_nl = jnp.dot(T_nl, 
    #            jnp.reshape(x_poly, nl_dims))
    
    X = jnp.reshape(X_lin, (self.original,)) +  jnp.reshape(X_nl, (self.original,))
    
    return X

  def get_T(self, x, dx, ddx):
    x = jnp.reshape(x, (1, self.latent, 1))
    dx = jnp.reshape(dx, (1, self.latent, 1))
    ddx = jnp.reshape(ddx, (1, self.latent, 1))
    in_x = jnp.reshape(x, (self.latent))
    
    x_poly = polynomialFeatures(jnp.reshape(x, ((-1,))), degree = 3, interaction_only = False, include_bias = False).reshape((-1,))
    x_poly = x_poly[self.latent:]
    nl_dims = len(x_poly)
    
    T = jnp.reshape(self.D_layers(in_x), (self.original, nl_dims))
    return T   
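# InverseCanonicalTransformer reconstructs the full state as the linear POD lift
# plus a learned nonlinear correction: X = V q + D(q).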


class Kinetic(nn.Module):
  def setup(self):
    self.T_layers = MLP_poly([1], degree = [2])
    self.latent = LATENT

  def __call__(self, x, dx):
    x = jnp.reshape(x, (1, self.latent, 1))
    dx = jnp.reshape(dx, (1, self.latent, 1))
    in_x = jnp.concatenate([x, dx], axis = 1)
    
    T = self.T_nl(x, dx)
    
    return T

  def T_lin(self, x, dx):
    dx = jnp.reshape(dx, (1, self.latent))
    T_lin = jnp.dot(dx, jnp.dot(MHAT, dx.T))
    return 0.5*jnp.reshape(T_lin, (-1,))

  def T_nl(self, x, dx):
    in_x = jnp.concatenate([x, dx], axis = 1)
    T = jnp.reshape(self.T_layers(in_x), (-1,))
    return T


class Potential(nn.Module):
  def setup(self):
    self.V_layers = MLP_poly_fom(features = [0], degree = [4])
    self.latent = LATENT
    self.original_dim = ORIGINAL_DIM

  def __call__(self, x, dx):
    x = jnp.reshape(x, (1, self.original_dim, 1))
    dx = jnp.reshape(dx, (1, self.original_dim, 1))
    in_x = jnp.concatenate([x, dx], axis = 1)
    
    V = self.T_nl(x, dx)
    
    return V

  def T_lin(self, x, dx):
    x = jnp.reshape(x, (1, self.latent))
    T_lin = 0.5*jnp.dot(x, jnp.dot(KHAT, x.T))
    return jnp.reshape(T_lin, (-1,))

  def T_nl(self, x, dx):
    in_x = jnp.concatenate([x, dx], axis = 1)
    T = jnp.reshape(self.V_layers(x), (-1,))  # only pass displacements
    return T

class Rayleigh(nn.Module):
  def setup(self):
    self.R_layers = MLP_poly([1], degree = [2])
    self.latent = LATENT

  def __call__(self, x, dx):
    x = jnp.reshape(x, (1, self.latent, 1))
    dx = jnp.reshape(dx, (1, self.latent, 1))
    in_x = jnp.concatenate([x, dx], axis = 1)
    
    R = self.T_nl(x, dx)
    
    return R

  def T_lin(self, x, dx):
    dx = jnp.reshape(dx, (1, self.latent))
    T_lin = 0.5*jnp.dot(dx, jnp.dot(CHAT, dx.T))
    return jnp.reshape(T_lin, (-1,))

  def T_nl(self, x, dx):
    in_x = jnp.concatenate([x, dx], axis = 1)
    T = jnp.reshape(self.R_layers(in_x), (-1,))
    return T


class Qnn(nn.Module):
  def setup(self):
    self.Qnn_layers = MLP_poly_latent([2])
    self.latent = LATENT

  def __call__(self, x, dx):
    x = jnp.reshape(x, (1, self.latent, 1))
    dx = jnp.reshape(dx, (1, self.latent, 1))
    in_x = jnp.concatenate([x, dx], axis = 1)
    
    V = jnp.reshape(self.Qnn_layers(in_x), (-1,))
    
    return V*0.0

class Minv_CNN(nn.Module):
  def setup(self):
    self.cnn_layers = CNN()
    self.latent = LATENT

  def __call__(self, Mmat):
    Mmat_re = jnp.reshape(Mmat, (self.latent, self.latent, 1))
    Minv_ = jnp.reshape(self.cnn_layers(Mmat_re), (self.latent, self.latent))

    return Minv_

class AEROM(nn.Module):
  def setup(self):
    self.T_layers = Kinetic()
    self.V_layers = Potential()
    self.R_layers = Rayleigh()
    self.Qnn_layers = Qnn()
    
    self.encoderNet_q = CanonicalTransformer()
    #self.encoderNet_dq = CanonicalTransformer()
    #self.encoderNet_ddq = CanonicalTransformer()
    
    self.decoderNet_q = InverseCanonicalTransformer()
    #self.decoderNet_dq = InverseCanonicalTransformer()
    #self.decoderNet_ddq = InverseCanonicalTransformer()
    self.latent = LATENT

  def __call__(self, x, dx, ddx):
    #qdot = self.encoderNet_dq(x, dx, ddx)
    #qdotdot = self.encoderNet_ddq(x, dx, ddx)
    T = self.T_layers(x,x)
    
    R = self.R_layers(x,x)
    q_ = self.encoderNet_q(x, x, x) #dummy call
    x_ = self.decoderNet_q(x, x, x) #dummy call
    V = self.V_layers(x_,x_)
    Qnn = self.Qnn_layers(x,x)
    #dx_ = self.decoderNet_dq(q, qdot, qdotdot)
    #ddx_ = self.decoderNet_ddq(q, qdot, qdotdot)
    return T


def get_quadratic_energy_jnp(K,x):
  xi = x.reshape((1,LATENT))
  #print(K.shape, xi.shape)
  E = 1/2.*jnp.dot(xi, jnp.dot(K, xi.T))

  return jnp.reshape(E, (-1,))


def Lagrangian(params, x, dx):
  T = eval_kinetic(params, x, dx) #Kinetic().apply({'params': params[0]}, x, dx)
  V = eval_potential(params, x, dx) #Potential().apply({'params': params[1]}, x, dx)
  return T-V

def eval_kinetic(params,x, dx):
  T = Kinetic().apply({'params': params[0]}, x, dx)
  B = CanonicalTransformer().apply({'params': params[4]}, x, method = CanonicalTransformer().get_T)
  A = jnp.linalg.pinv(jnp.reshape(B, (LATENT,LATENT)))
  Mtrans = jnp.dot(jnp.transpose(A), jnp.dot(MHAT, A))
  print(B.shape, A.shape, MHAT.shape, dx.shape, Mtrans.shape)
  T_lin = 0.5*jnp.dot(dx, jnp.dot(Mtrans, dx.T))

  return T*0.0 + T_lin

def eval_potential(params, x, dx):
  
  xfom = jnp.dot(VMAT, jnp.reshape(x, (LATENT, 1)))
  dxfom = jnp.dot(VMAT, jnp.reshape(dx, (LATENT, 1)))
    
  V = Potential().apply({'params': params[1]}, xfom, dxfom)
  B = CanonicalTransformer().apply({'params': params[4]}, x, method = CanonicalTransformer().get_T)
  A = jnp.linalg.pinv(jnp.reshape(B, (LATENT,LATENT)))
  Ktrans = jnp.dot(jnp.transpose(A), jnp.dot(KHAT, A))
  V_lin = 0.5*jnp.dot(x, jnp.dot(Ktrans, x.T))

  return V + V_lin

def eval_rayleigh(params, x, dx):
  R = Rayleigh().apply({'params': params[2]}, x, dx)
  B = CanonicalTransformer().apply({'params': params[4]}, x, method = CanonicalTransformer().get_T)
  A = jnp.linalg.pinv(jnp.reshape(B, (LATENT,LATENT)))
  Ctrans = jnp.dot(jnp.transpose(A), jnp.dot(CHAT, A))
  R_lin = 0.5*jnp.dot(dx, jnp.dot(Ctrans, dx.T))
  return (R + R_lin)*0.0


def eval_kinetic_lin(params,x, dx):
  B = CanonicalTransformer().apply({'params': params[4]}, x, method = CanonicalTransformer().get_T)
  A = jnp.linalg.pinv(jnp.reshape(B, (LATENT,LATENT)))
  Mtrans = jnp.dot(jnp.transpose(A), jnp.dot(MHAT, A))
  T_lin = 0.5*jnp.dot(dx, jnp.dot(Mtrans, dx.T))
  return T_lin

def eval_potential_lin(params, x, dx):
  B = CanonicalTransformer().apply({'params': params[4]}, x, method = CanonicalTransformer().get_T)
  A = jnp.linalg.pinv(jnp.reshape(B, (LATENT,LATENT)))
  Ktrans = jnp.dot(jnp.transpose(A), jnp.dot(KHAT, A))
  V_lin = 0.5*jnp.dot(x, jnp.dot(Ktrans, x.T))
  return V_lin

def eval_rayleigh_lin(params, x, dx):
  B = CanonicalTransformer().apply({'params': params[4]}, x, method = CanonicalTransformer().get_T)
  A = jnp.linalg.pinv(jnp.reshape(B, (LATENT,LATENT)))
  Ctrans = jnp.dot(jnp.transpose(A), jnp.dot(CHAT, A))
  R_lin = 0.5*jnp.dot(dx, jnp.dot(Ctrans, dx.T))
  return R_lin

def eval_kinetic_nl(params,x, dx):
  T = Kinetic().apply({'params': params[0]}, x, dx)
  return T 

def eval_potential_nl(params, x, dx):
  xfom = jnp.dot(VMAT, jnp.reshape(x, (LATENT, 1)))
  dxfom = jnp.dot(VMAT, jnp.reshape(dx, (LATENT, 1)))
  V = Potential().apply({'params': params[1]}, xfom, dxfom)
  return V 

def eval_rayleigh_nl(params, x, dx):
  R = Rayleigh().apply({'params': params[2]}, x, dx)
  return R


def encode_q(E_params, x, dx, ddx):
  R = CanonicalTransformer().apply({'params': E_params[0]}, x, dx, ddx)
  return R

def eval_Qnn(params, x, dx):
  R = Qnn().apply({'params': params[3]}, x, dx)
  return R


def encode_dq(E_params, x, dx, ddx):
  dx_dq = jax.jacfwd(encode_q, argnums = 1)(E_params, x, dx, ddx) #dx/dq
    
  dx = jnp.reshape(dx, (LATENT,))
  dX = jnp.matmul(dx_dq, dx)
  return jnp.reshape(dX, (LATENT,))

def encode_ddq(E_params, x, dx, ddx):
  ddx_ddq = jax.jacfwd(encode_dq, argnums = 2)(E_params, x, dx, ddx) #dxdot/dqdot
  ddx_dq = jax.jacfwd(encode_dq, argnums = 1)(E_params, x, dx, ddx) #dxdot/dq

  ddx = jnp.reshape(ddx, (LATENT,))
  dx = jnp.reshape(dx, (LATENT,))
    
  ddX = jnp.matmul(ddx_ddq, ddx) + jnp.matmul(ddx_dq, dx)
  return jnp.reshape(ddX, (LATENT,))
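# Chain rule for the latent acceleration: with qdot = qdot(x, xdot),
#   qddot = (d qdot / d xdot) xddot + (d qdot / d x) xdot,
# which is what the two jacfwd terms above assemble.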

def decode_q(D_params, x, dx, ddx):
  R = InverseCanonicalTransformer().apply({'params': D_params[0]}, x, dx, ddx)
  return R

def get_T_x(E_params, x, dx, ddx, comp=0):
  T = InverseCanonicalTransformer().apply({'params': E_params[comp]}, x, dx, ddx, method = InverseCanonicalTransformer().get_T)
  return T

def get_T_q(E_params, x, dx, ddx, comp=0):
  # Assumed encoder-side counterpart of get_T_x: encode_trans below calls get_T_q,
  # but no definition appeared in the original notebook.
  T = CanonicalTransformer().apply({'params': E_params[comp]}, x, method = CanonicalTransformer().get_T)
  return T

def decode_dq(D_params, x, dx, ddx):
  dx_dq = jax.jacfwd(decode_q, argnums = 1)(D_params, x, dx, ddx) #dx/dq
    
  dx = jnp.reshape(dx, (LATENT,))
  dX = jnp.matmul(dx_dq, dx)
  return jnp.reshape(dX, (ORIGINAL_DIM,))

def decode_ddq(D_params, x, dx, ddx):
  ddx_ddq = jax.jacfwd(decode_dq, argnums = 2)(D_params, x, dx, ddx) #dxdot/dqdot
  ddx_dq = jax.jacfwd(decode_dq, argnums = 1)(D_params, x, dx, ddx) #dxdot/dq

  ddx = jnp.reshape(ddx, (LATENT,))
  dx = jnp.reshape(dx, (LATENT,))
    
  ddX = jnp.matmul(ddx_ddq, ddx) + jnp.matmul(ddx_dq, dx)
  return jnp.reshape(ddX, (ORIGINAL_DIM,))

def encoder(E_params, x, dx, ddx):
  q = encode_q(E_params, x, dx, ddx)
  dq = encode_dq(E_params, x, dx, ddx)
  ddq = encode_ddq(E_params, x, dx, ddx)

  return q, dq, ddq

def decoder(D_params, x, dx, ddx):
  q = decode_q(D_params, x, dx, ddx)
  dq = decode_dq(D_params, x, dx, ddx)
  ddq = decode_ddq(D_params, x, dx, ddx)

  return q, dq, ddq

def init_params(data):
  rng = random.PRNGKey(0)
  rng, key = random.split(rng)
  pt = Kinetic().init(key, data[:1,:,0], data[:1,:,1])['params']
  pv = Potential().init(key, data[:1,:,0], data[:1,:,1])['params']

  pM = Minv_CNN().init(key, jnp.zeros((1, LATENT, LATENT, 1)))['params']

  return [pt, pv, pM]

def momentum(params, x, dx):
  latent = LATENT
  dx = jnp.reshape(dx, (-1, latent))
  x = jnp.reshape(x, (-1, latent))

  p = jax.jacfwd(Lagrangian, argnums = 2)(params, x, dx)
  p = jnp.reshape(p, (latent,))
  return p 
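# momentum(): generalized momentum p = dL/d(qdot), computed as a forward-mode
# Jacobian of the learned Lagrangian with respect to the velocity argument.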

def Hamiltonian(params, q, qdot, p):
  latent = LATENT
  qdot = jnp.reshape(qdot, (-1, latent))
  q = jnp.reshape(q, (-1, latent))

  L = Lagrangian(params,q,qdot)

  qdotp = jnp.dot(qdot, jnp.reshape(p, (latent, -1)))

  qdotp = jnp.reshape(qdotp, (-1,))

  H = qdotp - L
            
  return H
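# Hamiltonian(): Legendre transform H(q, p) = qdot . p - L(q, qdot).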

def canonical(params, q, qdot, p, Qext):
  latent = LATENT
  dH_dp = jax.jacfwd(Hamiltonian, argnums = 3)(params, q, qdot, p)
  dH_dq = jax.jacfwd(Hamiltonian, argnums = 1)(params, q, qdot, p)
  dR_dqdot = jax.jacfwd(eval_rayleigh, argnums = 2)(params, q, qdot)
  Qext = jnp.reshape(Qext, ((latent,)))
     
  dq_dt = (dH_dp).reshape((latent,))
  dp_dt = (-dH_dq - dR_dqdot + Qext).reshape((latent,))
        
  return [dq_dt, dp_dt]
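# canonical(): Hamilton's equations with Rayleigh dissipation and external forcing,
#   dq/dt =  dH/dp,
#   dp/dt = -dH/dq - dR/d(qdot) + Qext.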

def EulerLagrange(params, q, qdot, Qext):
  latent = LATENT
  p = momentum(params, q, qdot)

  [_, dp] = canonical(params, q, qdot, p, Qext)

  dL_dqdot = jax.jacfwd(Lagrangian, argnums = 2)(params, q, qdot)  
  dL_dq = jax.jacfwd(Lagrangian, argnums = 1)(params, q, qdot)

  dR_dqdot = jax.jacfwd(eval_rayleigh, argnums = 2)(params, q, qdot)

  d_dL_dqdot_dt = dp #(tape.gradient(L, q) - tape.gradient(R, qdot) + Fext)

  EL_resid = jnp.reshape(d_dL_dqdot_dt, (latent,)) - jnp.reshape(dL_dq, (latent,)) + jnp.reshape(dR_dqdot, (latent,)) - jnp.reshape(Qext, (latent,))

  #dt_check = F_loss(dL_dqdot, d_dL_dqdot_dt, dt = DT)
        
  return jnp.mean(jnp.square(EL_resid)) #+ dt_check*1.0e-3

def get_Mass_Matrix(params, q, dq):
  latent = LATENT
  M = jax.hessian(eval_kinetic, argnums = 2)(params, q, dq)
  M = jnp.reshape(M, (latent, latent))
  return M #jnp.reshape(M, (self.latent, self.latent))

def get_K_Matrix(params, q, dq):
  latent = LATENT
  M = jax.hessian(eval_potential, argnums = 1)(params, q, dq)
  M = jnp.reshape(M, (latent, latent))
  return M #jnp.reshape(M, (self.latent, self.latent))

def get_C_Matrix(params, q, dq):
  latent = LATENT
  M = jax.hessian(eval_rayleigh, argnums = 2)(params, q, dq)
  M = jnp.reshape(M, (latent, latent))
  return M #jnp.reshape(M, (self.latent, self.latent))

def get_Mass_Matrix_nl(params, q, dq):
  latent = LATENT
  M = jax.hessian(eval_kinetic_nl, argnums = 2)(params, q, dq)
  M = jnp.reshape(M, (latent, latent))
  return M #jnp.reshape(M, (self.latent, self.latent))

def get_K_Matrix_nl(params, q, dq):
  latent = LATENT
  M = jax.hessian(eval_potential_nl, argnums = 1)(params, q, dq)
  M = jnp.reshape(M, (latent, latent))
  return M #jnp.reshape(M, (self.latent, self.latent))

def get_C_Matrix_nl(params, q, dq):
  latent = LATENT
  M = jax.hessian(eval_rayleigh_nl, argnums = 2)(params, q, dq)
  M = jnp.reshape(M, (latent, latent))
  return M #jnp.reshape(M, (self.latent, self.latent))

def get_dM_dq(params, q, dq):
  dM = jax.jacfwd(get_Mass_Matrix, argnums = 1)(params, q, dq)
  return dM

def get_Minv(params, Mmat):
  latent = LATENT
  Minv = jnp.linalg.pinv(jnp.reshape(Mmat, (latent, latent)))
  return Minv


def get_Qnc(params, q, dq, Qext, in_params):
  Qnn_ = eval_Qnn(params, q, dq)  # eval_Qnn takes (params, x, dx); in_params is currently unused
 # Qnn_vec = jnp.zeros(Qext.shape)
  #x.at[idx].set(y) syntax
  #Qnn_vec.at[NDZ].set(Qnn_)
  Qnc = Qext - Qnn_
  return Qnc

def get_Qnn(params, q, dq, in_params):
  Qnn_ = eval_Qnn(params, q, dq)
  return Qnn_

def get_Qforce(D_params, q, dq, ddq, Fext):
  dx_dq = jax.jacfwd(decode_q, argnums = 1)(D_params, q, dq, ddq)
  dx_dq = jnp.transpose(dx_dq, (1,0))
  Fext = jnp.reshape(Fext, (ORIGINAL_DIM,))
  Qext = jnp.matmul(dx_dq, Fext)
    
  return jnp.reshape(Qext, (LATENT,))

def get_ddq(params, q, dq, dp):
  latent = LATENT
  Mmat = get_Mass_Matrix(params, q, dq)
  dM = get_dM_dq(params, q, dq)

  Mdq = jnp.matmul(dM, jnp.reshape(dq, (1,latent,1)))

  term2 = jnp.matmul(jnp.reshape(Mdq, (latent,latent)), 
                     jnp.reshape(dq, (latent,1)))

  term1 = jnp.reshape(dp, (latent,))
  term2 = jnp.reshape(term2, (latent,))

  Minv = get_Minv(params, Mmat)

  ddq = jnp.matmul(Minv, jnp.reshape(term1 - term2, (latent, 1)))

  ddq = jnp.reshape(ddq, (latent,))

  return ddq
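# get_ddq(): solves for the acceleration from
#   pdot = d/dt (dL/d(qdot)) = M(q) qddot + (dM/dq qdot) qdot,
# i.e. qddot = M^{-1} (pdot - (dM/dq qdot) qdot), matching term1 (dp) and term2 above.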

def general_checks(params, q, dq):

  latent = LATENT 

  Mmat = get_Mass_Matrix(params, q, dq)
  phat = jnp.matmul(jnp.reshape(Mmat, (latent, latent)),
                    jnp.reshape(dq, (latent, 1))
  )
  p = momentum(params, q, dq)
  p_check = jnp.square(jnp.reshape(p, (latent,)) - jnp.reshape(phat, (latent,)))

  ##

  qdotp = jnp.dot(jnp.reshape(dq, (1, latent)), 
                  jnp.reshape(p, (latent, 1))
  )

  That = 1/2.*qdotp 
  T = eval_kinetic(params, q, dq)
  V = eval_potential(params, q, dq)
  T_check = jnp.square(jnp.reshape(T, (-1,)) - jnp.reshape(That, (-1,)))

  ##

  H = Hamiltonian(params, q, dq, p)
  Hhat = T + V
  H_check = jnp.square(H - Hhat)

  ##

  #Minv = get_Minv(params, Mmat)
  #eye_hat = jnp.matmul(jnp.reshape(Mmat, (latent, latent)), 
  #                     jnp.reshape(Minv, (latent, latent))
  #)
  #eye = jnp.eye(latent)

  #Minv_check = jnp.sum(jnp.square(
  #    eye_hat - eye 
  #))

  ##

  Veq = eval_potential(params, jnp.zeros(q.shape), jnp.zeros(dq.shape))

  Veq_check = jnp.square(Veq)

  ##

  neg_M = jnp.mean(jnp.square(jax.nn.relu(jnp.negative(Mmat))))
  neg_T = jnp.mean(jnp.square(jax.nn.relu(jnp.negative(T))))
  neg_V = jnp.mean(jnp.square(jax.nn.relu(jnp.negative(V))))

  neg_checks = neg_M + neg_T + neg_V

  Mnorm = jnp.linalg.norm(Mmat)
  Mnorm_loss = jnp.square(Mnorm.reshape((1,)) - np.linalg.norm(np.array([[1.1,0],[0,1.1]]))/Mnorm)

  ## 
  ix_off =  ~np.eye(Mmat.shape[0],dtype=bool)
  offM = jnp.sum(jnp.square(Mmat[ix_off]))

  ## 
    
  M11 = jnp.square(Mmat[0,0] - 1.0)

  ## 
  #[-0.1,0.5, 0.1, 0]
  #x0 = jnp.array(w0[:LATENT])
  #dx0 = jnp.array(w0[LATENT:])

  #L0 = Lagrangian(params, x0, dx0)
    
  #L0_check = jnp.mean(jnp.square(L0 - L0_analyt))

  return jnp.mean(T_check)*10 + jnp.mean(p_check)*10 + jnp.mean(H_check)*10 #+  L0_check*0.1 #+ neg_checks*1 #+ H_check*0.1 #+ jnp.mean(M11)# + offM #+ Mnorm_loss


def F_loss(p, dp, dt):
  alpha = ALPHA
  beta = BETA
        
  M = MSTEPS
  #dt = DT

  p = jnp.reshape(p, (-1, LATENT))
  dp = jnp.reshape(dp, (-1, LATENT))

  pmax = jnp.max(jnp.array([jnp.max(jnp.abs(p)), 1])) # normalize by the peak amplitude (clipped to >= 1) so the residual is amplitude-independent
  p = p/pmax
  dp = dp/pmax
        
  Y = alpha[0]*p[M:, :] + dt*beta[0]*dp[M:, :]
        
  for m in range(1, M+1):
    Y = Y + alpha[m]*p[M-m:-m, :] + dt*beta[m]*dp[M-m:-m,:]
            
  return jnp.mean(jnp.square(Y))
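# F_loss(): linear multistep residual. With MSTEPS = 2 and nodepy's alpha/beta
# convention this is (up to sign) the Adams-Bashforth-2 defect
#   Y_n = -( p_{n+2} - p_{n+1} - dt*(3/2 dp_{n+1} - 1/2 dp_n) ),
# so jnp.mean(Y**2) penalizes trajectories for which dp fails to be the time
# derivative of p under the AB2 scheme.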


def jnp_cov(x):
  nDof = LATENT
  mean_x = jnp.reshape(jnp.mean(x, axis = 0), (1, nDof))
  mx = jnp.matmul(mean_x.T, mean_x)
  vx = jnp.matmul(x.T, x) #/tf.cast(tf.shape(x)[0], tf.float32)
  cov_xx = vx - mx
  return cov_xx
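# jnp_cov(): unnormalized covariance X^T X - mean^T mean; note the commented-out
# division by the sample count, so the result scales with batch size.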
(4,)
In [20]:
def l2_loss(x, alpha):
    return alpha * (x ** 2).mean()

def l1_loss(x, alpha):
  return alpha * jnp.abs(x).sum()

@jax.jit
def f_diff(params, x, dx, Fext):
  latent = LATENT

  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]

  #Qnc = jax.vmap(get_Qnc, in_axes = (None, 0, 0, 0))(params_, x, dx, Fext, in_params)
  p = jax.vmap(momentum, in_axes = (None, 0, 0))(params_, x, dx)
  [dq,dp] = jax.vmap(canonical, in_axes = (None, 0, 0, 0, 0))(params_, x, dx, p, Fext)
  ddq = jax.vmap(get_ddq, in_axes = (None, 0, 0, 0))(params_, x, dx, dp)

  X = jnp.concatenate([jnp.reshape(dx, (-1, latent)), 
                       jnp.reshape(ddq, (-1, latent))], axis = -1)

  return X

def check_1(params, batch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']

  params_ = [T_params, V_params, R_params]

  q = batch[:, :, 0]
  qdot = batch[:, :, 1]
  ddq = batch[:, :, 2]
  Qext = batch[:, :, 3]

  p = jax.vmap(momentum, in_axes = (None, 0, 0))(params_, q, qdot)

  [dq, dp] = jax.vmap(canonical, in_axes = (None, 0, 0, 0, 0))(params_, q, qdot, p, Qext)

  dL_dqdot = jax.vmap(jax.jacfwd(Lagrangian, argnums = 2), in_axes = (None, 0, 0))(params_, q, qdot)  
  dL_dq = jax.vmap(jax.jacfwd(Lagrangian, argnums = 1), in_axes = (None, 0, 0))(params_, q, qdot)

  d_dL_dqdot_dt = dp #(tape.gradient(L, q) - tape.gradient(R, qdot) + Fext)

  return [p,dp,dL_dqdot,dL_dq,dq]




def encode_trans(params, batch, comp=0):
    E_params = [params['encoderNet_q']]
    D_params = [params['decoderNet_q']]
    X = batch[:, :, 0]
    dX = batch[:, :, 1]
    ddX = batch[:, :, 2]
    Fext = batch[:, :, 3]

    T = jax.vmap(get_T_q, in_axes = (None, 0, 0, 0, None))(E_params, X, dX, ddX, comp)    

    return T

def decode_trans(params, batch,comp=0):
    D_params = [params['decoderNet_q']]
    Q = batch[:, :, 0]
    dQ = batch[:, :, 1]
    ddQ = batch[:, :, 2]

    T = jax.vmap(get_T_x, in_axes = (None, 0, 0, 0, None))(D_params, Q, dQ, ddQ, comp) 

    return T

def encode_predict(params, batch):
    E_params = [params['encoderNet_q']]
    D_params = [params['decoderNet_q']]
    X = batch[:, :, 0]
    dX = batch[:, :, 1]
    ddX = batch[:, :, 2]
    Fext = batch[:, :, 3]

    x,dx,ddx = jax.vmap(encoder, in_axes = (None, 0, 0, 0))(E_params, X, dX, ddX)    
    #Qext = jax.vmap(get_Qforce, in_axes = (None, 0, 0, 0, 0))(D_params, x, dx, ddx, Fext)
    
    Q = jnp.concatenate([jnp.reshape(x, (-1,LATENT,1)),
                        jnp.reshape(dx, (-1,LATENT,1)),
                        jnp.reshape(ddx, (-1,LATENT,1)),
                        jnp.reshape(Fext, (-1,LATENT,1))], axis = -1)
    
    return Q

def decode_predict(params, batch):
    D_params = [params['decoderNet_q']]
    Q = batch[:, :, 0]
    dQ = batch[:, :, 1]
    ddQ = batch[:, :, 2]

    x,dx,ddx = jax.vmap(decoder, in_axes = (None, 0, 0, 0))(D_params, Q, dQ, ddQ)    
    
    X = jnp.concatenate([jnp.reshape(x, (-1,ORIGINAL_DIM,1)),
                        jnp.reshape(dx, (-1,ORIGINAL_DIM,1)),
                        jnp.reshape(ddx, (-1,ORIGINAL_DIM,1))], axis = -1)
    
    return X

@jax.jit
def predict_ddq(params, batch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]

  x = batch[:, :, 0]
  dx = batch[:, :, 1]
  ddx = batch[:, :, 2]
  Qext = batch[:, :, 3]


  #Qnc = jax.vmap(get_Qnc, in_axes = (None, 0, 0, 0, 0))(params_, x, dx, Qext, in_params)
  p = jax.vmap(momentum, in_axes = (None, 0, 0))(params_, x, dx)
  [_,dp] = jax.vmap(canonical, in_axes = (None, 0, 0, 0, 0))(params_, x, dx, p, Qext)
  ddq = jax.vmap(get_ddq, in_axes = (None, 0, 0, 0))(params_, x, dx, dp)

  return ddq

def predict_Qnc(params, batch, pbatch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params]

  x = batch[:, :, 0]
  dx = batch[:, :, 1]
  ddx = batch[:, :, 2]
  Qext = batch[:, :, 3]

  in_params = pbatch[:,:]

  Qnc = jax.vmap(get_Qnc, in_axes = (None, 0, 0, 0, 0))(params_, x, dx, Qext, in_params)

  return Qnc

def predict_kinetic(params, batch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]

  x = batch[:, :, 0]
  dx = batch[:, :, 1]
  ddx = batch[:, :, 2]

  T = jax.vmap(eval_kinetic, in_axes = (None, 0, 0))(params_, x, dx)
  T_lin = jax.vmap(eval_kinetic_lin, in_axes = (None, 0, 0))(params_, x, dx)
  T_nl = jax.vmap(eval_kinetic_nl, in_axes = (None, 0, 0))(params_, x, dx)

  return T,T_lin,T_nl

def predict_potential(params, batch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]

  x = batch[:, :, 0]
  dx = batch[:, :, 1]
  ddx = batch[:, :, 2]

  V = jax.vmap(eval_potential, in_axes = (None, 0, 0))(params_, x, dx)
  V_lin = jax.vmap(eval_potential_lin, in_axes = (None, 0, 0))(params_, x, dx)
  V_nl = jax.vmap(eval_potential_nl, in_axes = (None, 0, 0))(params_, x, dx)

  return V,V_lin,V_nl

def predict_Hamiltonian(params, batch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]

  x = batch[:, :, 0]
  dx = batch[:, :, 1]
  ddx = batch[:, :, 2]

  p = jax.vmap(momentum, in_axes = (None, 0, 0))(params_, x, dx)
  H = jax.vmap(Hamiltonian, in_axes = (None, 0, 0, 0))(params_, x, dx, p)
    
  return H

def predict_mass(params, batch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]

  x = batch[:, :, 0]
  dx = batch[:, :, 1]
  ddx = batch[:, :, 2]

  M = jax.vmap(get_Mass_Matrix, in_axes = (None, 0, 0))(params_, x, dx)
  Minv = jax.vmap(get_Minv, in_axes = (None, 0))(params_, M)

  return M, Minv

def predict_MKC(params, batch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]

  x = batch[:, :, 0]
  dx = batch[:, :, 1]
  ddx = batch[:, :, 2]

  M = jax.vmap(get_Mass_Matrix, in_axes = (None, 0, 0))(params_, x, dx)
  K = jax.vmap(get_K_Matrix, in_axes = (None, 0, 0))(params_, x, dx)
  C = jax.vmap(get_C_Matrix, in_axes = (None, 0, 0))(params_, x, dx)

  M_nl = jax.vmap(get_Mass_Matrix_nl, in_axes = (None, 0, 0))(params_, x, dx)
  K_nl = jax.vmap(get_K_Matrix_nl, in_axes = (None, 0, 0))(params_, x, dx)
  C_nl = jax.vmap(get_C_Matrix_nl, in_axes = (None, 0, 0))(params_, x, dx)

  return M,K,C,M_nl,K_nl,C_nl

def compute_metrics(loss):
  metrics = {
      'loss': loss
  }
  return metrics


def lasso_loss(params, alpha = 1e-3):
  # NOTE: assumes Kinetic/Potential expose a get_coeffs method, which is not
  # defined above; lasso_loss is not referenced by the training losses below.
  T_params = params['T_layers']
  V_params = params['V_layers']
  wT = Kinetic().apply({'params': T_params}, method = Kinetic().get_coeffs)
  wV = Potential().apply({'params': V_params}, method = Potential().get_coeffs)

  return l1_loss(wT, alpha) + l1_loss(wV, alpha)

def loss_fn_(params, batch, xbatch):
  T_params = params['T_layers']
  V_params = params['V_layers']
  R_params = params['R_layers']
  Qnn_params = params['Qnn_layers']

  params_ = [T_params, V_params, R_params, Qnn_params, params['encoderNet_q']]
  E_params = [params['encoderNet_q']]
  D_params = [params['decoderNet_q']]

  x_ = batch[:, :, 0]
  dx_ = batch[:, :, 1]
  ddx_ = batch[:, :, 2]
  Qext = batch[:, :, 3]

  X = xbatch[:, :, 0]
  dX = xbatch[:, :, 1]
  ddX = xbatch[:, :, 2]

  print(x_.shape, X.shape)
  x,dx,ddx = jax.vmap(encoder, in_axes = (None, 0, 0, 0))(E_params, x_,dx_,ddx_)
  X_, dX_, ddX_ =jax.vmap(decoder, in_axes = (None, 0, 0, 0))(D_params, x, dx, ddx)

  T_nl = jax.tree_util.tree_leaves(params['decoderNet_q'])[-1].T
  T_lin = VMAT
    
  tangency = jnp.mean(jnp.dot(T_lin.T, T_nl)**2)

  print(x.shape, dx.shape)

  p = jax.vmap(momentum, in_axes = (None, 0, 0))(params_, x, dx)
  H = jax.vmap(Hamiltonian, in_axes = (None, 0, 0, 0))(params_, x, dx, p)
    
  H_std = jnp.mean(jnp.diff(H, axis = 0))
  
  
  [dq,dp] = jax.vmap(canonical, in_axes = (None, 0, 0, 0, 0))(params_, x, dx, p, Qext)
  #EL_loss = jax.vmap(EulerLagrange, in_axes = (None, 0, 0, 0))(params_, x, dq, Qext)
  ddq = jax.vmap(get_ddq, in_axes = (None, 0, 0, 0))(params_, x, dq, dp)
  dL_dqdot = jax.vmap(jax.jacfwd(Lagrangian, argnums = 2), in_axes = (None, 0, 0))(params_, x, dq)

  F_loss_dp = F_loss(p, dp, dt = DT)*1000
  F_loss_dx = F_loss(dq, ddq, dt = DT)*1000  
  F_loss_x = F_loss(x, dq, dt = DT)*1000

  recon_loss_q = jnp.mean(jnp.square((X[:,:] - X_[:,:])/XMAX[0]))
  recon_loss_dq = jnp.mean(jnp.square((dX[:,:] - dX_[:,:])/XMAX[1]))
  recon_loss_ddq = jnp.mean(jnp.square((ddX[:,:] - ddX_[:,:])/XMAX[2]))
    
  recon_loss_ddq_enc = jnp.mean(jnp.square((ddx - ddq)/QMAX[2]))

  recon_loss = (recon_loss_ddq + recon_loss_dq + recon_loss_q) 

  checks_loss = jax.vmap(general_checks, in_axes = (None, 0, 0))(params_, x, dq)

  d_dL_dqdot_dt = dp #(tape.gradient(L, q) - tape.gradient(R, qdot) + Fext)

  EL_check = F_loss(dL_dqdot, d_dL_dqdot_dt, dt = DT) 

  #supervQ = jnp.mean(jnp.square((X - x)/XMAX[0]))*0.0 #+ jnp.mean(jnp.square((X[:,3] - x[:,2])/XMAX[0])) #force the first latent dim to be the same

  loss = jnp.array([F_loss_dp, F_loss_x, F_loss_dx, jnp.mean(checks_loss)*5, jnp.mean(EL_check), recon_loss_ddq_enc*800, 0.1*recon_loss, 0.1*tangency, H_std])

  return loss

def loss_fn(params, batch, xbatch):
  loss = loss_fn_(params, batch, xbatch)
  loss = jnp.sum(loss)

  loss += sum(
        l2_loss(w, alpha=5e-3)
        for w in jax.tree_util.tree_leaves(params)
    )

  return loss

@jax.jit
def train_step(state, batch, xbatch):
  grad_fn = jax.value_and_grad(loss_fn, has_aux=False)
  loss, grads = grad_fn(state.params, batch, xbatch)
  state = state.apply_gradients(grads=grads)
  metrics = compute_metrics(loss)
  return state, metrics


def train_epoch(state, train_ds, xtrain_ds, batch_size, epoch, rng):
  """Train for a single epoch."""
  train_ds_size = len(train_ds)
  steps_per_epoch = train_ds_size // batch_size

  train_ds_batched = train_ds.batch(batch_size).prefetch(1)
  train_ds_batched = tfds.as_numpy(train_ds_batched)

  xtrain_ds_batched = xtrain_ds.batch(batch_size).prefetch(1)
  xtrain_ds_batched = tfds.as_numpy(xtrain_ds_batched)

  batch_metrics = []
  for batch, xbatch in zip(train_ds_batched, xtrain_ds_batched):
    state, metrics = train_step(state, batch, xbatch)
    batch_metrics.append(metrics)

  # compute mean of metrics across each batch in epoch.
  batch_metrics_np = jax.device_get(batch_metrics)
  epoch_metrics_np = {
      k: np.mean([metrics[k] for metrics in batch_metrics_np])
      for k in batch_metrics_np[0]}

  if epoch % 100 == 0:
    print('train epoch: %d, loss: %.8f' % (
        epoch, epoch_metrics_np['loss']))

  return state


@jax.jit
def eval_step(params, batch, xbatch):
  loss = loss_fn(params, batch, xbatch)
  return compute_metrics(loss)

def eval_model(params, test_ds, xtest_ds):

  test_ds_batched = test_ds.batch(len(test_ds)).prefetch(1)
  test_ds_batched = tfds.as_numpy(test_ds_batched)

  xtest_ds_batched = xtest_ds.batch(len(xtest_ds)).prefetch(1)
  xtest_ds_batched = tfds.as_numpy(xtest_ds_batched)

  for batch, xbatch in zip(test_ds_batched, xtest_ds_batched):
    metrics = eval_step(params, batch, xbatch)
    metrics = jax.device_get(metrics)
    summary = jax.tree_util.tree_map(lambda x: x.item(), metrics)

  return summary['loss']
In [21]:
rng = random.PRNGKey(0)
rng, key = random.split(rng)
init_rng = {'params': random.PRNGKey(0), 'dropout': random.PRNGKey(1)}
initial_variables = AEROM().init(rng, QnormTrainData[:1,:,0], QnormTrainData[:1,:,1], QnormTrainData[:1,:,0])

state = train_state.TrainState.create(
    apply_fn = AEROM().apply,
    params = initial_variables['params'],
    tx = optax.adam(
          learning_rate=1e-4,
          b1=0.9,
          b2=0.98,
          eps=1e-7)
)
(6,) (12,) (48,)
In [22]:
#jax.tree_util.tree_leaves(initial_variables['params']['decoderNet_q'])[-1].shape
In [52]:
from flax.training import checkpoints
dir_name = "./chkpts/chkpts_DQuinn_NOQnn_v01_AE_decTenc_UPDdof6_DecIdent_v03_PolyOnly_noDamp_NLManifold_Vq_diffH_upd/"

load = 0
purge = 1

if load:
    test = os.listdir(dir_name)
    chk_latest = checkpoints.latest_checkpoint(dir_name)
    state = checkpoints.restore_checkpoint(chk_latest, state)

    
if purge:
    test = os.listdir(dir_name)
    for item in test:
          os.remove(os.path.join(dir_name, item))
In [53]:
from flax.training import checkpoints
num_epochs = 150000
batch_size = 250

for epoch in range(1, num_epochs + 1):
    # Use a separate PRNG key for shuffling the training data each epoch
    rng, input_rng = jax.random.split(rng)
    # Run an optimization step over a training batch
    state = train_epoch(state, q_train_dataset, x_train_dataset, batch_size, epoch, input_rng)
    # Evaluate on the test set after each training epoch
    test_loss= eval_model(state.params, q_test_dataset, x_test_dataset)
    if epoch % 50 == 0:
        print(' test epoch: %d, loss: %.8f' % (
                epoch, test_loss))
    if epoch % 50 == 0:
        checkpoints.save_checkpoint(dir_name, state, epoch, keep=3)
 test epoch: 50, loss: 0.00752504
train epoch: 100, loss: 0.00453971
 test epoch: 100, loss: 0.00751823
 test epoch: 150, loss: 0.00751147
train epoch: 200, loss: 0.00453115
 test epoch: 200, loss: 0.00750481
 test epoch: 250, loss: 0.00749812
train epoch: 300, loss: 0.00452283
 test epoch: 300, loss: 0.00749146
 test epoch: 350, loss: 0.00748487
train epoch: 400, loss: 0.00451475
 test epoch: 400, loss: 0.00747829
 test epoch: 450, loss: 0.00747182
train epoch: 500, loss: 0.00450690
 test epoch: 500, loss: 0.00746535
 test epoch: 550, loss: 0.00745895
train epoch: 600, loss: 0.00449927
 test epoch: 600, loss: 0.00745256
 test epoch: 650, loss: 0.00744625
train epoch: 700, loss: 0.00449186
 test epoch: 700, loss: 0.00743996
 test epoch: 750, loss: 0.00743372
train epoch: 800, loss: 0.00448465
 test epoch: 800, loss: 0.00742762
 test epoch: 850, loss: 0.00742148
train epoch: 900, loss: 0.00447763
 test epoch: 900, loss: 0.00741543
 test epoch: 950, loss: 0.00740936
train epoch: 1000, loss: 0.00447081
 test epoch: 1000, loss: 0.00740343
...
train epoch: 12000, loss: 0.00411133
 test epoch: 12000, loss: 0.00677106
 test epoch: 12050, loss: 0.00676963
train epoch: 12100, loss: 0.00410941
 test epoch: 12100, loss: 0.00676817
 test epoch: 12150, loss: 0.00676678
train epoch: 12200, loss: 0.00410749
 test epoch: 12200, loss: 0.00676532
 test epoch: 12250, loss: 0.00676388
train epoch: 12300, loss: 0.00410558
 test epoch: 12300, loss: 0.00676249
 test epoch: 12350, loss: 0.00676109
train epoch: 12400, loss: 0.00410369
 test epoch: 12400, loss: 0.00675971
 test epoch: 12450, loss: 0.00675830
train epoch: 12500, loss: 0.00410180
 test epoch: 12500, loss: 0.00675681
 test epoch: 12550, loss: 0.00675545
train epoch: 12600, loss: 0.00409990
 test epoch: 12600, loss: 0.00675406
 test epoch: 12650, loss: 0.00675264
train epoch: 12700, loss: 0.00409802
 test epoch: 12700, loss: 0.00675126
 test epoch: 12750, loss: 0.00674991
train epoch: 12800, loss: 0.00409615
 test epoch: 12800, loss: 0.00674859
 test epoch: 12850, loss: 0.00674716
train epoch: 12900, loss: 0.00409428
 test epoch: 12900, loss: 0.00674578
 test epoch: 12950, loss: 0.00674438
train epoch: 13000, loss: 0.00409241
 test epoch: 13000, loss: 0.00674290
 test epoch: 13050, loss: 0.00674159
train epoch: 13100, loss: 0.00409055
 test epoch: 13100, loss: 0.00674019
 test epoch: 13150, loss: 0.00673876
train epoch: 13200, loss: 0.00408870
 test epoch: 13200, loss: 0.00673737
 test epoch: 13250, loss: 0.00673603
train epoch: 13300, loss: 0.00408685
 test epoch: 13300, loss: 0.00673463
 test epoch: 13350, loss: 0.00673321
train epoch: 13400, loss: 0.00408501
 test epoch: 13400, loss: 0.00673186
 test epoch: 13450, loss: 0.00673051
train epoch: 13500, loss: 0.00408318
 test epoch: 13500, loss: 0.00672912
 test epoch: 13550, loss: 0.00672771
train epoch: 13600, loss: 0.00408134
 test epoch: 13600, loss: 0.00672633
 test epoch: 13650, loss: 0.00672488
train epoch: 13700, loss: 0.00407953
 test epoch: 13700, loss: 0.00672348
 test epoch: 13750, loss: 0.00672211
train epoch: 13800, loss: 0.00407771
 test epoch: 13800, loss: 0.00672078
 test epoch: 13850, loss: 0.00671938
train epoch: 13900, loss: 0.00407589
 test epoch: 13900, loss: 0.00671804
 test epoch: 13950, loss: 0.00671665
train epoch: 14000, loss: 0.00407409
 test epoch: 14000, loss: 0.00671535
 test epoch: 14050, loss: 0.00671404
train epoch: 14100, loss: 0.00407228
 test epoch: 14100, loss: 0.00671267
 test epoch: 14150, loss: 0.00671133
train epoch: 14200, loss: 0.00407050
 test epoch: 14200, loss: 0.00670995
 test epoch: 14250, loss: 0.00670858
train epoch: 14300, loss: 0.00406871
 test epoch: 14300, loss: 0.00670724
 test epoch: 14350, loss: 0.00670592
train epoch: 14400, loss: 0.00406692
 test epoch: 14400, loss: 0.00670458
 test epoch: 14450, loss: 0.00670321
train epoch: 14500, loss: 0.00406515
 test epoch: 14500, loss: 0.00670187
 test epoch: 14550, loss: 0.00670054
train epoch: 14600, loss: 0.00406338
 test epoch: 14600, loss: 0.00669923
 test epoch: 14650, loss: 0.00669791
train epoch: 14700, loss: 0.00406162
 test epoch: 14700, loss: 0.00669648
 test epoch: 14750, loss: 0.00669513
train epoch: 14800, loss: 0.00405986
 test epoch: 14800, loss: 0.00669375
 test epoch: 14850, loss: 0.00669240
train epoch: 14900, loss: 0.00405811
 test epoch: 14900, loss: 0.00669108
 test epoch: 14950, loss: 0.00668968
train epoch: 15000, loss: 0.00405638
 test epoch: 15000, loss: 0.00668840
 test epoch: 15050, loss: 0.00668701
train epoch: 15100, loss: 0.00405465
 test epoch: 15100, loss: 0.00668559
 test epoch: 15150, loss: 0.00668425
train epoch: 15200, loss: 0.00405293
 test epoch: 15200, loss: 0.00668292
 test epoch: 15250, loss: 0.00668158
train epoch: 15300, loss: 0.00405121
 test epoch: 15300, loss: 0.00668016
 test epoch: 15350, loss: 0.00667881
train epoch: 15400, loss: 0.00404951
 test epoch: 15400, loss: 0.00667748
 test epoch: 15450, loss: 0.00667619
train epoch: 15500, loss: 0.00404781
 test epoch: 15500, loss: 0.00667484
 test epoch: 15550, loss: 0.00667350
train epoch: 15600, loss: 0.00404612
 test epoch: 15600, loss: 0.00667210
 test epoch: 15650, loss: 0.00667076
train epoch: 15700, loss: 0.00404443
 test epoch: 15700, loss: 0.00666941
 test epoch: 15750, loss: 0.00666804
train epoch: 15800, loss: 0.00404276
 test epoch: 15800, loss: 0.00666661
 test epoch: 15850, loss: 0.00666532
train epoch: 15900, loss: 0.00404109
 test epoch: 15900, loss: 0.00666395
 test epoch: 15950, loss: 0.00666262
train epoch: 16000, loss: 0.00403944
 test epoch: 16000, loss: 0.00666128
 test epoch: 16050, loss: 0.00665991
train epoch: 16100, loss: 0.00403779
 test epoch: 16100, loss: 0.00665858
 test epoch: 16150, loss: 0.00665713
train epoch: 16200, loss: 0.00403615
 test epoch: 16200, loss: 0.00665582
 test epoch: 16250, loss: 0.00665448
train epoch: 16300, loss: 0.00403452
 test epoch: 16300, loss: 0.00665312
 test epoch: 16350, loss: 0.00665176
train epoch: 16400, loss: 0.00403291
 test epoch: 16400, loss: 0.00665033
 test epoch: 16450, loss: 0.00664893
train epoch: 16500, loss: 0.00403130
 test epoch: 16500, loss: 0.00664752
 test epoch: 16550, loss: 0.00664615
train epoch: 16600, loss: 0.00402970
 test epoch: 16600, loss: 0.00664476
 test epoch: 16650, loss: 0.00664336
train epoch: 16700, loss: 0.00402811
 test epoch: 16700, loss: 0.00664191
 test epoch: 16750, loss: 0.00664053
train epoch: 16800, loss: 0.00402653
 test epoch: 16800, loss: 0.00663915
 test epoch: 16850, loss: 0.00663772
train epoch: 16900, loss: 0.00402495
 test epoch: 16900, loss: 0.00663635
 test epoch: 16950, loss: 0.00663490
train epoch: 17000, loss: 0.00402340
 test epoch: 17000, loss: 0.00663352
 test epoch: 17050, loss: 0.00663215
train epoch: 17100, loss: 0.00402184
 test epoch: 17100, loss: 0.00663073
 test epoch: 17150, loss: 0.00662926
train epoch: 17200, loss: 0.00402030
 test epoch: 17200, loss: 0.00662786
 test epoch: 17250, loss: 0.00662644
train epoch: 17300, loss: 0.00401876
 test epoch: 17300, loss: 0.00662511
 test epoch: 17350, loss: 0.00662369
train epoch: 17400, loss: 0.00401723
 test epoch: 17400, loss: 0.00662222
 test epoch: 17450, loss: 0.00662079
train epoch: 17500, loss: 0.00401573
 test epoch: 17500, loss: 0.00661937
 test epoch: 17550, loss: 0.00661790
train epoch: 17600, loss: 0.00401422
 test epoch: 17600, loss: 0.00661646
 test epoch: 17650, loss: 0.00661507
train epoch: 17700, loss: 0.00401273
 test epoch: 17700, loss: 0.00661357
 test epoch: 17750, loss: 0.00661213
train epoch: 17800, loss: 0.00401124
 test epoch: 17800, loss: 0.00661067
 test epoch: 17850, loss: 0.00660922
train epoch: 17900, loss: 0.00400978
 test epoch: 17900, loss: 0.00660772
 test epoch: 17950, loss: 0.00660620
train epoch: 18000, loss: 0.00400832
 test epoch: 18000, loss: 0.00660475
 test epoch: 18050, loss: 0.00660329
train epoch: 18100, loss: 0.00400686
 test epoch: 18100, loss: 0.00660187
 test epoch: 18150, loss: 0.00660040
train epoch: 18200, loss: 0.00400542
 test epoch: 18200, loss: 0.00659891
 test epoch: 18250, loss: 0.00659742
train epoch: 18300, loss: 0.00400399
 test epoch: 18300, loss: 0.00659599
 test epoch: 18350, loss: 0.00659454
train epoch: 18400, loss: 0.00400257
 test epoch: 18400, loss: 0.00659308
 test epoch: 18450, loss: 0.00659154
train epoch: 18500, loss: 0.00400115
 test epoch: 18500, loss: 0.00659013
 test epoch: 18550, loss: 0.00658864
train epoch: 18600, loss: 0.00399976
 test epoch: 18600, loss: 0.00658713
 test epoch: 18650, loss: 0.00658565
train epoch: 18700, loss: 0.00399837
 test epoch: 18700, loss: 0.00658417
 test epoch: 18750, loss: 0.00658264
train epoch: 18800, loss: 0.00399698
 test epoch: 18800, loss: 0.00658116
 test epoch: 18850, loss: 0.00657969
train epoch: 18900, loss: 0.00399562
 test epoch: 18900, loss: 0.00657821
 test epoch: 18950, loss: 0.00657666
train epoch: 19000, loss: 0.00399427
 test epoch: 19000, loss: 0.00657514
 test epoch: 19050, loss: 0.00657360
train epoch: 19100, loss: 0.00399292
 test epoch: 19100, loss: 0.00657216
 test epoch: 19150, loss: 0.00657065
train epoch: 19200, loss: 0.00399157
 test epoch: 19200, loss: 0.00656909
 test epoch: 19250, loss: 0.00656760
train epoch: 19300, loss: 0.00399024
 test epoch: 19300, loss: 0.00656612
 test epoch: 19350, loss: 0.00656456
train epoch: 19400, loss: 0.00398893
 test epoch: 19400, loss: 0.00656304
 test epoch: 19450, loss: 0.00656157
train epoch: 19500, loss: 0.00398762
 test epoch: 19500, loss: 0.00656005
 test epoch: 19550, loss: 0.00655852
train epoch: 19600, loss: 0.00398632
 test epoch: 19600, loss: 0.00655708
 test epoch: 19650, loss: 0.00655555
train epoch: 19700, loss: 0.00398503
 test epoch: 19700, loss: 0.00655407
 test epoch: 19750, loss: 0.00655256
train epoch: 19800, loss: 0.00398375
 test epoch: 19800, loss: 0.00655106
 test epoch: 19850, loss: 0.00654949
train epoch: 19900, loss: 0.00398247
 test epoch: 19900, loss: 0.00654797
 test epoch: 19950, loss: 0.00654646
train epoch: 20000, loss: 0.00398121
 test epoch: 20000, loss: 0.00654495
 test epoch: 20050, loss: 0.00654346
train epoch: 20100, loss: 0.00397996
 test epoch: 20100, loss: 0.00654193
 test epoch: 20150, loss: 0.00654043
train epoch: 20200, loss: 0.00397872
 test epoch: 20200, loss: 0.00653902
 test epoch: 20250, loss: 0.00653746
train epoch: 20300, loss: 0.00397749
 test epoch: 20300, loss: 0.00653597
 test epoch: 20350, loss: 0.00653442
train epoch: 20400, loss: 0.00397627
 test epoch: 20400, loss: 0.00653287
 test epoch: 20450, loss: 0.00653133
train epoch: 20500, loss: 0.00397506
 test epoch: 20500, loss: 0.00652981
 test epoch: 20550, loss: 0.00652834
train epoch: 20600, loss: 0.00397384
 test epoch: 20600, loss: 0.00652680
 test epoch: 20650, loss: 0.00652535
train epoch: 20700, loss: 0.00397265
 test epoch: 20700, loss: 0.00652383
 test epoch: 20750, loss: 0.00652229
train epoch: 20800, loss: 0.00397146
 test epoch: 20800, loss: 0.00652078
 test epoch: 20850, loss: 0.00651928
train epoch: 20900, loss: 0.00397029
 test epoch: 20900, loss: 0.00651766
 test epoch: 20950, loss: 0.00651613
train epoch: 21000, loss: 0.00396912
 test epoch: 21000, loss: 0.00651464
 test epoch: 21050, loss: 0.00651312
train epoch: 21100, loss: 0.00396796
 test epoch: 21100, loss: 0.00651161
 test epoch: 21150, loss: 0.00651012
train epoch: 21200, loss: 0.00396681
 test epoch: 21200, loss: 0.00650862
 test epoch: 21250, loss: 0.00650706
train epoch: 21300, loss: 0.00396567
 test epoch: 21300, loss: 0.00650552
 test epoch: 21350, loss: 0.00650399
train epoch: 21400, loss: 0.00396452
 test epoch: 21400, loss: 0.00650248
 test epoch: 21450, loss: 0.00650097
train epoch: 21500, loss: 0.00396340
 test epoch: 21500, loss: 0.00649944
 test epoch: 21550, loss: 0.00649800
train epoch: 21600, loss: 0.00396228
 test epoch: 21600, loss: 0.00649644
 test epoch: 21650, loss: 0.00649492
train epoch: 21700, loss: 0.00396116
 test epoch: 21700, loss: 0.00649346
 test epoch: 21750, loss: 0.00649202
train epoch: 21800, loss: 0.00396006
 test epoch: 21800, loss: 0.00649052
 test epoch: 21850, loss: 0.00648908
train epoch: 21900, loss: 0.00395896
 test epoch: 21900, loss: 0.00648758
 test epoch: 21950, loss: 0.00648609
train epoch: 22000, loss: 0.00395787
 test epoch: 22000, loss: 0.00648456
 test epoch: 22050, loss: 0.00648310
train epoch: 22100, loss: 0.00395680
 test epoch: 22100, loss: 0.00648161
 test epoch: 22150, loss: 0.00648012
train epoch: 22200, loss: 0.00395572
 test epoch: 22200, loss: 0.00647864
 test epoch: 22250, loss: 0.00647718
train epoch: 22300, loss: 0.00395464
 test epoch: 22300, loss: 0.00647573
 test epoch: 22350, loss: 0.00647424
train epoch: 22400, loss: 0.00395359
 test epoch: 22400, loss: 0.00647276
 test epoch: 22450, loss: 0.00647126
train epoch: 22500, loss: 0.00395253
 test epoch: 22500, loss: 0.00646980
 test epoch: 22550, loss: 0.00646834
train epoch: 22600, loss: 0.00395150
 test epoch: 22600, loss: 0.00646690
 test epoch: 22650, loss: 0.00646540
train epoch: 22700, loss: 0.00395045
 test epoch: 22700, loss: 0.00646395
 test epoch: 22750, loss: 0.00646248
train epoch: 22800, loss: 0.00394941
 test epoch: 22800, loss: 0.00646105
 test epoch: 22850, loss: 0.00645962
train epoch: 22900, loss: 0.00394838
 test epoch: 22900, loss: 0.00645817
 test epoch: 22950, loss: 0.00645678
train epoch: 23000, loss: 0.00394737
 test epoch: 23000, loss: 0.00645529
 test epoch: 23050, loss: 0.00645387
train epoch: 23100, loss: 0.00394636
 test epoch: 23100, loss: 0.00645247
 test epoch: 23150, loss: 0.00645105
train epoch: 23200, loss: 0.00394535
 test epoch: 23200, loss: 0.00644970
 test epoch: 23250, loss: 0.00644824
train epoch: 23300, loss: 0.00394434
 test epoch: 23300, loss: 0.00644681
 test epoch: 23350, loss: 0.00644538
train epoch: 23400, loss: 0.00394335
 test epoch: 23400, loss: 0.00644393
 test epoch: 23450, loss: 0.00644251
train epoch: 23500, loss: 0.00394236
 test epoch: 23500, loss: 0.00644111
 test epoch: 23550, loss: 0.00643977
train epoch: 23600, loss: 0.00394137
 test epoch: 23600, loss: 0.00643836
 test epoch: 23650, loss: 0.00643696
train epoch: 23700, loss: 0.00394040
 test epoch: 23700, loss: 0.00643552
 test epoch: 23750, loss: 0.00643410
train epoch: 23800, loss: 0.00393942
 test epoch: 23800, loss: 0.00643268
 test epoch: 23850, loss: 0.00643128
train epoch: 23900, loss: 0.00393845
 test epoch: 23900, loss: 0.00642983
 test epoch: 23950, loss: 0.00642846
train epoch: 24000, loss: 0.00393749
 test epoch: 24000, loss: 0.00642707
 test epoch: 24050, loss: 0.00642567
train epoch: 24100, loss: 0.00393654
 test epoch: 24100, loss: 0.00642430
 test epoch: 24150, loss: 0.00642292
train epoch: 24200, loss: 0.00393558
 test epoch: 24200, loss: 0.00642155
 test epoch: 24250, loss: 0.00642012
train epoch: 24300, loss: 0.00393463
 test epoch: 24300, loss: 0.00641874
 test epoch: 24350, loss: 0.00641737
train epoch: 24400, loss: 0.00393369
 test epoch: 24400, loss: 0.00641603
 test epoch: 24450, loss: 0.00641469
train epoch: 24500, loss: 0.00393276
 test epoch: 24500, loss: 0.00641330
 test epoch: 24550, loss: 0.00641198
train epoch: 24600, loss: 0.00393183
 test epoch: 24600, loss: 0.00641065
 test epoch: 24650, loss: 0.00640929
train epoch: 24700, loss: 0.00393090
 test epoch: 24700, loss: 0.00640793
 test epoch: 24750, loss: 0.00640653
train epoch: 24800, loss: 0.00392997
 test epoch: 24800, loss: 0.00640526
 test epoch: 24850, loss: 0.00640394
train epoch: 24900, loss: 0.00392905
 test epoch: 24900, loss: 0.00640262
 test epoch: 24950, loss: 0.00640121
train epoch: 25000, loss: 0.00392815
 test epoch: 25000, loss: 0.00639985
 test epoch: 25050, loss: 0.00639847
train epoch: 25100, loss: 0.00392724
 test epoch: 25100, loss: 0.00639715
 test epoch: 25150, loss: 0.00639584
train epoch: 25200, loss: 0.00392634
 test epoch: 25200, loss: 0.00639453
 test epoch: 25250, loss: 0.00639319
train epoch: 25300, loss: 0.00392544
 test epoch: 25300, loss: 0.00639184
 test epoch: 25350, loss: 0.00639054
train epoch: 25400, loss: 0.00392454
 test epoch: 25400, loss: 0.00638930
 test epoch: 25450, loss: 0.00638797
train epoch: 25500, loss: 0.00392365
 test epoch: 25500, loss: 0.00638665
 test epoch: 25550, loss: 0.00638532
train epoch: 25600, loss: 0.00392277
 test epoch: 25600, loss: 0.00638396
 test epoch: 25650, loss: 0.00638265
train epoch: 25700, loss: 0.00392187
 test epoch: 25700, loss: 0.00638134
 test epoch: 25750, loss: 0.00638004
train epoch: 25800, loss: 0.00392099
 test epoch: 25800, loss: 0.00637873
 test epoch: 25850, loss: 0.00637742
train epoch: 25900, loss: 0.00392012
 test epoch: 25900, loss: 0.00637611
 test epoch: 25950, loss: 0.00637480
train epoch: 26000, loss: 0.00391924
 test epoch: 26000, loss: 0.00637352
 test epoch: 26050, loss: 0.00637223
train epoch: 26100, loss: 0.00391837
 test epoch: 26100, loss: 0.00637090
 test epoch: 26150, loss: 0.00636958
train epoch: 26200, loss: 0.00391751
 test epoch: 26200, loss: 0.00636832
 test epoch: 26250, loss: 0.00636706
train epoch: 26300, loss: 0.00391664
 test epoch: 26300, loss: 0.00636578
 test epoch: 26350, loss: 0.00636453
train epoch: 26400, loss: 0.00391579
 test epoch: 26400, loss: 0.00636323
 test epoch: 26450, loss: 0.00636188
train epoch: 26500, loss: 0.00391493
 test epoch: 26500, loss: 0.00636055
 test epoch: 26550, loss: 0.00635924
train epoch: 26600, loss: 0.00391408
 test epoch: 26600, loss: 0.00635798
 test epoch: 26650, loss: 0.00635665
train epoch: 26700, loss: 0.00391323
 test epoch: 26700, loss: 0.00635536
 test epoch: 26750, loss: 0.00635403
train epoch: 26800, loss: 0.00391238
 test epoch: 26800, loss: 0.00635276
 test epoch: 26850, loss: 0.00635149
train epoch: 26900, loss: 0.00391154
 test epoch: 26900, loss: 0.00635020
 test epoch: 26950, loss: 0.00634899
train epoch: 27000, loss: 0.00391069
 test epoch: 27000, loss: 0.00634773
 test epoch: 27050, loss: 0.00634642
train epoch: 27100, loss: 0.00390986
 test epoch: 27100, loss: 0.00634513
 test epoch: 27150, loss: 0.00634380
train epoch: 27200, loss: 0.00390902
 test epoch: 27200, loss: 0.00634253
 test epoch: 27250, loss: 0.00634127
train epoch: 27300, loss: 0.00390819
 test epoch: 27300, loss: 0.00633995
 test epoch: 27350, loss: 0.00633868
train epoch: 27400, loss: 0.00390736
 test epoch: 27400, loss: 0.00633738
 test epoch: 27450, loss: 0.00633611
train epoch: 27500, loss: 0.00390653
 test epoch: 27500, loss: 0.00633484
 test epoch: 27550, loss: 0.00633350
train epoch: 27600, loss: 0.00390569
 test epoch: 27600, loss: 0.00633222
 test epoch: 27650, loss: 0.00633096
train epoch: 27700, loss: 0.00390487
 test epoch: 27700, loss: 0.00632969
 test epoch: 27750, loss: 0.00632838
train epoch: 27800, loss: 0.00390405
 test epoch: 27800, loss: 0.00632709
 test epoch: 27850, loss: 0.00632574
train epoch: 27900, loss: 0.00390324
 test epoch: 27900, loss: 0.00632447
 test epoch: 27950, loss: 0.00632317
train epoch: 28000, loss: 0.00390242
 test epoch: 28000, loss: 0.00632191
 test epoch: 28050, loss: 0.00632063
train epoch: 28100, loss: 0.00390162
 test epoch: 28100, loss: 0.00631933
 test epoch: 28150, loss: 0.00631804
train epoch: 28200, loss: 0.00390079
 test epoch: 28200, loss: 0.00631671
 test epoch: 28250, loss: 0.00631551
train epoch: 28300, loss: 0.00389998
 test epoch: 28300, loss: 0.00631422
 test epoch: 28350, loss: 0.00631292
train epoch: 28400, loss: 0.00389918
 test epoch: 28400, loss: 0.00631160
 test epoch: 28450, loss: 0.00631033
train epoch: 28500, loss: 0.00389838
 test epoch: 28500, loss: 0.00630906
 test epoch: 28550, loss: 0.00630772
train epoch: 28600, loss: 0.00389758
 test epoch: 28600, loss: 0.00630640
 test epoch: 28650, loss: 0.00630512
train epoch: 28700, loss: 0.00389677
 test epoch: 28700, loss: 0.00630382
 test epoch: 28750, loss: 0.00630261
train epoch: 28800, loss: 0.00389596
 test epoch: 28800, loss: 0.00630131
 test epoch: 28850, loss: 0.00630003
train epoch: 28900, loss: 0.00389517
 test epoch: 28900, loss: 0.00629871
 test epoch: 28950, loss: 0.00629741
train epoch: 29000, loss: 0.00389437
 test epoch: 29000, loss: 0.00629616
 test epoch: 29050, loss: 0.00629481
train epoch: 29100, loss: 0.00389358
 test epoch: 29100, loss: 0.00629349
 test epoch: 29150, loss: 0.00629222
train epoch: 29200, loss: 0.00389278
 test epoch: 29200, loss: 0.00629089
 test epoch: 29250, loss: 0.00628963
train epoch: 29300, loss: 0.00389199
 test epoch: 29300, loss: 0.00628832
 test epoch: 29350, loss: 0.00628710
train epoch: 29400, loss: 0.00389119
 test epoch: 29400, loss: 0.00628584
 test epoch: 29450, loss: 0.00628456
train epoch: 29500, loss: 0.00389041
 test epoch: 29500, loss: 0.00628316
 test epoch: 29550, loss: 0.00628189
train epoch: 29600, loss: 0.00388962
 test epoch: 29600, loss: 0.00628056
 test epoch: 29650, loss: 0.00627917
train epoch: 29700, loss: 0.00388884
 test epoch: 29700, loss: 0.00627787
 test epoch: 29750, loss: 0.00627660
train epoch: 29800, loss: 0.00388804
 test epoch: 29800, loss: 0.00627535
 test epoch: 29850, loss: 0.00627405
train epoch: 29900, loss: 0.00388726
 test epoch: 29900, loss: 0.00627277
 test epoch: 29950, loss: 0.00627147
train epoch: 30000, loss: 0.00388647
 test epoch: 30000, loss: 0.00627020
 test epoch: 30050, loss: 0.00626885
train epoch: 30100, loss: 0.00388568
 test epoch: 30100, loss: 0.00626751
 test epoch: 30150, loss: 0.00626627
train epoch: 30200, loss: 0.00388490
 test epoch: 30200, loss: 0.00626495
 test epoch: 30250, loss: 0.00626368
train epoch: 30300, loss: 0.00388411
 test epoch: 30300, loss: 0.00626239
 test epoch: 30350, loss: 0.00626107
train epoch: 30400, loss: 0.00388333
 test epoch: 30400, loss: 0.00625979
 test epoch: 30450, loss: 0.00625844
train epoch: 30500, loss: 0.00388255
 test epoch: 30500, loss: 0.00625711
 test epoch: 30550, loss: 0.00625594
train epoch: 30600, loss: 0.00388177
 test epoch: 30600, loss: 0.00625466
 test epoch: 30650, loss: 0.00625329
train epoch: 30700, loss: 0.00388099
 test epoch: 30700, loss: 0.00625203
 test epoch: 30750, loss: 0.00625070
train epoch: 30800, loss: 0.00388021
 test epoch: 30800, loss: 0.00624938
 test epoch: 30850, loss: 0.00624813
train epoch: 30900, loss: 0.00387943
 test epoch: 30900, loss: 0.00624685
 test epoch: 30950, loss: 0.00624560
train epoch: 31000, loss: 0.00387865
 test epoch: 31000, loss: 0.00624433
 test epoch: 31050, loss: 0.00624306
train epoch: 31100, loss: 0.00387787
 test epoch: 31100, loss: 0.00624175
 test epoch: 31150, loss: 0.00624049
train epoch: 31200, loss: 0.00387708
 test epoch: 31200, loss: 0.00623916
 test epoch: 31250, loss: 0.00623784
train epoch: 31300, loss: 0.00387631
 test epoch: 31300, loss: 0.00623656
 test epoch: 31350, loss: 0.00623524
train epoch: 31400, loss: 0.00387554
 test epoch: 31400, loss: 0.00623395
 test epoch: 31450, loss: 0.00623264
train epoch: 31500, loss: 0.00387475
 test epoch: 31500, loss: 0.00623142
 test epoch: 31550, loss: 0.00623016
train epoch: 31600, loss: 0.00387398
 test epoch: 31600, loss: 0.00622884
 test epoch: 31650, loss: 0.00622755
train epoch: 31700, loss: 0.00387321
 test epoch: 31700, loss: 0.00622633
 test epoch: 31750, loss: 0.00622504
train epoch: 31800, loss: 0.00387243
 test epoch: 31800, loss: 0.00622378
 test epoch: 31850, loss: 0.00622247
train epoch: 31900, loss: 0.00387165
 test epoch: 31900, loss: 0.00622115
 test epoch: 31950, loss: 0.00621992
train epoch: 32000, loss: 0.00387085
 test epoch: 32000, loss: 0.00621869
 test epoch: 32050, loss: 0.00621745
train epoch: 32100, loss: 0.00387007
 test epoch: 32100, loss: 0.00621615
 test epoch: 32150, loss: 0.00621490
train epoch: 32200, loss: 0.00386928
 test epoch: 32200, loss: 0.00621360
 test epoch: 32250, loss: 0.00621233
train epoch: 32300, loss: 0.00386850
 test epoch: 32300, loss: 0.00621100
 test epoch: 32350, loss: 0.00620973
train epoch: 32400, loss: 0.00386771
 test epoch: 32400, loss: 0.00620850
 test epoch: 32450, loss: 0.00620720
train epoch: 32500, loss: 0.00386692
 test epoch: 32500, loss: 0.00620595
 test epoch: 32550, loss: 0.00620469
train epoch: 32600, loss: 0.00386614
 test epoch: 32600, loss: 0.00620342
 test epoch: 32650, loss: 0.00620212
train epoch: 32700, loss: 0.00386534
 test epoch: 32700, loss: 0.00620090
 test epoch: 32750, loss: 0.00619961
train epoch: 32800, loss: 0.00386456
 test epoch: 32800, loss: 0.00619834
 test epoch: 32850, loss: 0.00619707
train epoch: 32900, loss: 0.00386376
 test epoch: 32900, loss: 0.00619582
 test epoch: 32950, loss: 0.00619454
train epoch: 33000, loss: 0.00386297
 test epoch: 33000, loss: 0.00619327
 test epoch: 33050, loss: 0.00619208
train epoch: 33100, loss: 0.00386217
 test epoch: 33100, loss: 0.00619089
 test epoch: 33150, loss: 0.00618961
train epoch: 33200, loss: 0.00386138
 test epoch: 33200, loss: 0.00618839
 test epoch: 33250, loss: 0.00618719
train epoch: 33300, loss: 0.00386057
 test epoch: 33300, loss: 0.00618599
 test epoch: 33350, loss: 0.00618476
train epoch: 33400, loss: 0.00385978
 test epoch: 33400, loss: 0.00618352
 test epoch: 33450, loss: 0.00618227
train epoch: 33500, loss: 0.00385897
 test epoch: 33500, loss: 0.00618103
 test epoch: 33550, loss: 0.00617977
train epoch: 33600, loss: 0.00385817
 test epoch: 33600, loss: 0.00617853
 test epoch: 33650, loss: 0.00617732
train epoch: 33700, loss: 0.00385736
 test epoch: 33700, loss: 0.00617605
 test epoch: 33750, loss: 0.00617486
train epoch: 33800, loss: 0.00385655
 test epoch: 33800, loss: 0.00617375
 test epoch: 33850, loss: 0.00617252
train epoch: 33900, loss: 0.00385574
 test epoch: 33900, loss: 0.00617123
 test epoch: 33950, loss: 0.00617006
train epoch: 34000, loss: 0.00385493
 test epoch: 34000, loss: 0.00616884
 test epoch: 34050, loss: 0.00616768
train epoch: 34100, loss: 0.00385411
 test epoch: 34100, loss: 0.00616649
 test epoch: 34150, loss: 0.00616531
train epoch: 34200, loss: 0.00385329
 test epoch: 34200, loss: 0.00616413
 test epoch: 34250, loss: 0.00616291
train epoch: 34300, loss: 0.00385246
 test epoch: 34300, loss: 0.00616176
 test epoch: 34350, loss: 0.00616060
train epoch: 34400, loss: 0.00385164
 test epoch: 34400, loss: 0.00615941
 test epoch: 34450, loss: 0.00615821
train epoch: 34500, loss: 0.00385082
 test epoch: 34500, loss: 0.00615700
 test epoch: 34550, loss: 0.00615582
train epoch: 34600, loss: 0.00385000
 test epoch: 34600, loss: 0.00615465
 test epoch: 34650, loss: 0.00615342
train epoch: 34700, loss: 0.00384916
 test epoch: 34700, loss: 0.00615228
 test epoch: 34750, loss: 0.00615112
train epoch: 34800, loss: 0.00384834
 test epoch: 34800, loss: 0.00614996
 test epoch: 34850, loss: 0.00614885
train epoch: 34900, loss: 0.00384750
 test epoch: 34900, loss: 0.00614766
 test epoch: 34950, loss: 0.00614650
train epoch: 35000, loss: 0.00384667
 test epoch: 35000, loss: 0.00614533
 test epoch: 35050, loss: 0.00614413
train epoch: 35100, loss: 0.00384583
 test epoch: 35100, loss: 0.00614294
 test epoch: 35150, loss: 0.00614180
train epoch: 35200, loss: 0.00384500
 test epoch: 35200, loss: 0.00614065
 test epoch: 35250, loss: 0.00613947
train epoch: 35300, loss: 0.00384415
 test epoch: 35300, loss: 0.00613829
 test epoch: 35350, loss: 0.00613716
train epoch: 35400, loss: 0.00384331
 test epoch: 35400, loss: 0.00613595
 test epoch: 35450, loss: 0.00613488
train epoch: 35500, loss: 0.00384248
 test epoch: 35500, loss: 0.00613371
 test epoch: 35550, loss: 0.00613256
train epoch: 35600, loss: 0.00384162
 test epoch: 35600, loss: 0.00613141
 test epoch: 35650, loss: 0.00613024
train epoch: 35700, loss: 0.00384078
 test epoch: 35700, loss: 0.00612905
 test epoch: 35750, loss: 0.00612790
train epoch: 35800, loss: 0.00383994
 test epoch: 35800, loss: 0.00612674
 test epoch: 35850, loss: 0.00612559
train epoch: 35900, loss: 0.00383908
 test epoch: 35900, loss: 0.00612439
 test epoch: 35950, loss: 0.00612325
train epoch: 36000, loss: 0.00383824
 test epoch: 36000, loss: 0.00612215
 test epoch: 36050, loss: 0.00612093
train epoch: 36100, loss: 0.00383738
 test epoch: 36100, loss: 0.00611974
 test epoch: 36150, loss: 0.00611860
train epoch: 36200, loss: 0.00383654
 test epoch: 36200, loss: 0.00611740
 test epoch: 36250, loss: 0.00611623
train epoch: 36300, loss: 0.00383568
 test epoch: 36300, loss: 0.00611512
 test epoch: 36350, loss: 0.00611385
train epoch: 36400, loss: 0.00383483
 test epoch: 36400, loss: 0.00611264
 test epoch: 36450, loss: 0.00611145
train epoch: 36500, loss: 0.00383399
 test epoch: 36500, loss: 0.00611024
 test epoch: 36550, loss: 0.00610907
train epoch: 36600, loss: 0.00383314
 test epoch: 36600, loss: 0.00610785
 test epoch: 36650, loss: 0.00610667
train epoch: 36700, loss: 0.00383230
 test epoch: 36700, loss: 0.00610546
 test epoch: 36750, loss: 0.00610423
train epoch: 36800, loss: 0.00383145
 test epoch: 36800, loss: 0.00610296
 test epoch: 36850, loss: 0.00610174
train epoch: 36900, loss: 0.00383061
 test epoch: 36900, loss: 0.00610050
 test epoch: 36950, loss: 0.00609925
train epoch: 37000, loss: 0.00382976
 test epoch: 37000, loss: 0.00609805
 test epoch: 37050, loss: 0.00609678
train epoch: 37100, loss: 0.00382892
 test epoch: 37100, loss: 0.00609551
 test epoch: 37150, loss: 0.00609434
train epoch: 37200, loss: 0.00382808
 test epoch: 37200, loss: 0.00609309
 test epoch: 37250, loss: 0.00609172
train epoch: 37300, loss: 0.00382725
 test epoch: 37300, loss: 0.00609040
 test epoch: 37350, loss: 0.00608910
train epoch: 37400, loss: 0.00382641
 test epoch: 37400, loss: 0.00608787
 test epoch: 37450, loss: 0.00608653
train epoch: 37500, loss: 0.00382559
 test epoch: 37500, loss: 0.00608526
 test epoch: 37550, loss: 0.00608399
train epoch: 37600, loss: 0.00382475
 test epoch: 37600, loss: 0.00608266
 test epoch: 37650, loss: 0.00608135
train epoch: 37700, loss: 0.00382393
 test epoch: 37700, loss: 0.00608002
 test epoch: 37750, loss: 0.00607867
train epoch: 37800, loss: 0.00382310
 test epoch: 37800, loss: 0.00607734
 test epoch: 37850, loss: 0.00607595
train epoch: 37900, loss: 0.00382229
 test epoch: 37900, loss: 0.00607457
 test epoch: 37950, loss: 0.00607320
train epoch: 38000, loss: 0.00382147
 test epoch: 38000, loss: 0.00607187
 test epoch: 38050, loss: 0.00607046
train epoch: 38100, loss: 0.00382065
 test epoch: 38100, loss: 0.00606909
 test epoch: 38150, loss: 0.00606774
train epoch: 38200, loss: 0.00381985
 test epoch: 38200, loss: 0.00606635
 test epoch: 38250, loss: 0.00606495
train epoch: 38300, loss: 0.00381903
 test epoch: 38300, loss: 0.00606352
 test epoch: 38350, loss: 0.00606215
train epoch: 38400, loss: 0.00381823
 test epoch: 38400, loss: 0.00606071
 test epoch: 38450, loss: 0.00605922
train epoch: 38500, loss: 0.00381743
 test epoch: 38500, loss: 0.00605782
 test epoch: 38550, loss: 0.00605646
train epoch: 38600, loss: 0.00381663
 test epoch: 38600, loss: 0.00605502
 test epoch: 38650, loss: 0.00605357
train epoch: 38700, loss: 0.00381584
 test epoch: 38700, loss: 0.00605208
 test epoch: 38750, loss: 0.00605067
train epoch: 38800, loss: 0.00381504
 test epoch: 38800, loss: 0.00604926
 test epoch: 38850, loss: 0.00604784
train epoch: 38900, loss: 0.00381425
 test epoch: 38900, loss: 0.00604634
 test epoch: 38950, loss: 0.00604485
train epoch: 39000, loss: 0.00381347
 test epoch: 39000, loss: 0.00604338
 test epoch: 39050, loss: 0.00604192
train epoch: 39100, loss: 0.00381269
 test epoch: 39100, loss: 0.00604042
 test epoch: 39150, loss: 0.00603894
train epoch: 39200, loss: 0.00381191
 test epoch: 39200, loss: 0.00603742
 test epoch: 39250, loss: 0.00603588
train epoch: 39300, loss: 0.00381113
 test epoch: 39300, loss: 0.00603443
 test epoch: 39350, loss: 0.00603299
train epoch: 39400, loss: 0.00381036
 test epoch: 39400, loss: 0.00603153
 test epoch: 39450, loss: 0.00603000
train epoch: 39500, loss: 0.00380959
 test epoch: 39500, loss: 0.00602857
 test epoch: 39550, loss: 0.00602706
train epoch: 39600, loss: 0.00380882
 test epoch: 39600, loss: 0.00602557
 test epoch: 39650, loss: 0.00602408
train epoch: 39700, loss: 0.00380806
 test epoch: 39700, loss: 0.00602255
 test epoch: 39750, loss: 0.00602101
train epoch: 39800, loss: 0.00380729
 test epoch: 39800, loss: 0.00601947
 test epoch: 39850, loss: 0.00601801
train epoch: 39900, loss: 0.00380653
 test epoch: 39900, loss: 0.00601659
 test epoch: 39950, loss: 0.00601506
train epoch: 40000, loss: 0.00380577
 test epoch: 40000, loss: 0.00601356
 test epoch: 40050, loss: 0.00601207
train epoch: 40100, loss: 0.00380501
 test epoch: 40100, loss: 0.00601052
 test epoch: 40150, loss: 0.00600901
train epoch: 40200, loss: 0.00380427
 test epoch: 40200, loss: 0.00600747
 test epoch: 40250, loss: 0.00600600
train epoch: 40300, loss: 0.00380351
 test epoch: 40300, loss: 0.00600449
 test epoch: 40350, loss: 0.00600302
train epoch: 40400, loss: 0.00380276
 test epoch: 40400, loss: 0.00600147
 test epoch: 40450, loss: 0.00599998
train epoch: 40500, loss: 0.00380202
 test epoch: 40500, loss: 0.00599854
 test epoch: 40550, loss: 0.00599700
train epoch: 40600, loss: 0.00380128
 test epoch: 40600, loss: 0.00599544
 test epoch: 40650, loss: 0.00599391
train epoch: 40700, loss: 0.00380052
 test epoch: 40700, loss: 0.00599243
 test epoch: 40750, loss: 0.00599101
train epoch: 40800, loss: 0.00379979
 test epoch: 40800, loss: 0.00598951
 test epoch: 40850, loss: 0.00598797
train epoch: 40900, loss: 0.00379904
 test epoch: 40900, loss: 0.00598651
 test epoch: 40950, loss: 0.00598506
train epoch: 41000, loss: 0.00379830
 test epoch: 41000, loss: 0.00598353
 test epoch: 41050, loss: 0.00598202
train epoch: 41100, loss: 0.00379757
 test epoch: 41100, loss: 0.00598054
 test epoch: 41150, loss: 0.00597906
train epoch: 41200, loss: 0.00379683
 test epoch: 41200, loss: 0.00597761
 test epoch: 41250, loss: 0.00597605
train epoch: 41300, loss: 0.00379609
 test epoch: 41300, loss: 0.00597457
 test epoch: 41350, loss: 0.00597316
train epoch: 41400, loss: 0.00379536
 test epoch: 41400, loss: 0.00597168
 test epoch: 41450, loss: 0.00597015
train epoch: 41500, loss: 0.00379462
 test epoch: 41500, loss: 0.00596865
 test epoch: 41550, loss: 0.00596722
train epoch: 41600, loss: 0.00379390
 test epoch: 41600, loss: 0.00596577
 test epoch: 41650, loss: 0.00596432
train epoch: 41700, loss: 0.00379316
 test epoch: 41700, loss: 0.00596286
 test epoch: 41750, loss: 0.00596145
train epoch: 41800, loss: 0.00379243
 test epoch: 41800, loss: 0.00595999
 test epoch: 41850, loss: 0.00595852
train epoch: 41900, loss: 0.00379169
 test epoch: 41900, loss: 0.00595704
 test epoch: 41950, loss: 0.00595562
train epoch: 42000, loss: 0.00379097
 test epoch: 42000, loss: 0.00595417
 test epoch: 42050, loss: 0.00595273
train epoch: 42100, loss: 0.00379025
 test epoch: 42100, loss: 0.00595128
 test epoch: 42150, loss: 0.00594986
train epoch: 42200, loss: 0.00378952
 test epoch: 42200, loss: 0.00594847
 test epoch: 42250, loss: 0.00594705
train epoch: 42300, loss: 0.00378879
 test epoch: 42300, loss: 0.00594562
 test epoch: 42350, loss: 0.00594423
train epoch: 42400, loss: 0.00378806
 test epoch: 42400, loss: 0.00594282
 test epoch: 42450, loss: 0.00594145
train epoch: 42500, loss: 0.00378734
 test epoch: 42500, loss: 0.00594004
 test epoch: 42550, loss: 0.00593866
train epoch: 42600, loss: 0.00378662
 test epoch: 42600, loss: 0.00593726
 test epoch: 42650, loss: 0.00593590
train epoch: 42700, loss: 0.00378589
 test epoch: 42700, loss: 0.00593449
 test epoch: 42750, loss: 0.00593311
train epoch: 42800, loss: 0.00378517
 test epoch: 42800, loss: 0.00593177
 test epoch: 42850, loss: 0.00593036
train epoch: 42900, loss: 0.00378445
 test epoch: 42900, loss: 0.00592898
 test epoch: 42950, loss: 0.00592765
train epoch: 43000, loss: 0.00378371
 test epoch: 43000, loss: 0.00592635
 test epoch: 43050, loss: 0.00592496
train epoch: 43100, loss: 0.00378300
 test epoch: 43100, loss: 0.00592356
 test epoch: 43150, loss: 0.00592221
train epoch: 43200, loss: 0.00378227
 test epoch: 43200, loss: 0.00592095
 test epoch: 43250, loss: 0.00591959
train epoch: 43300, loss: 0.00378155
 test epoch: 43300, loss: 0.00591823
 test epoch: 43350, loss: 0.00591691
train epoch: 43400, loss: 0.00378084
 test epoch: 43400, loss: 0.00591555
 test epoch: 43450, loss: 0.00591424
train epoch: 43500, loss: 0.00378012
 test epoch: 43500, loss: 0.00591289
 test epoch: 43550, loss: 0.00591155
train epoch: 43600, loss: 0.00377941
 test epoch: 43600, loss: 0.00591024
 test epoch: 43650, loss: 0.00590895
train epoch: 43700, loss: 0.00377868
 test epoch: 43700, loss: 0.00590767
 test epoch: 43750, loss: 0.00590635
train epoch: 43800, loss: 0.00377797
 test epoch: 43800, loss: 0.00590503
 test epoch: 43850, loss: 0.00590379
train epoch: 43900, loss: 0.00377725
 test epoch: 43900, loss: 0.00590248
 test epoch: 43950, loss: 0.00590119
train epoch: 44000, loss: 0.00377653
 test epoch: 44000, loss: 0.00589986
 test epoch: 44050, loss: 0.00589869
train epoch: 44100, loss: 0.00377582
 test epoch: 44100, loss: 0.00589742
 test epoch: 44150, loss: 0.00589618
train epoch: 44200, loss: 0.00377511
 test epoch: 44200, loss: 0.00589488
 test epoch: 44250, loss: 0.00589365
train epoch: 44300, loss: 0.00377440
 test epoch: 44300, loss: 0.00589232
 test epoch: 44350, loss: 0.00589107
train epoch: 44400, loss: 0.00377369
 test epoch: 44400, loss: 0.00588979
 test epoch: 44450, loss: 0.00588853
train epoch: 44500, loss: 0.00377298
 test epoch: 44500, loss: 0.00588731
 test epoch: 44550, loss: 0.00588601
train epoch: 44600, loss: 0.00377227
 test epoch: 44600, loss: 0.00588479
 test epoch: 44650, loss: 0.00588361
train epoch: 44700, loss: 0.00377156
 test epoch: 44700, loss: 0.00588234
 test epoch: 44750, loss: 0.00588111
train epoch: 44800, loss: 0.00377085
 test epoch: 44800, loss: 0.00587989
 test epoch: 44850, loss: 0.00587867
train epoch: 44900, loss: 0.00377015
 test epoch: 44900, loss: 0.00587742
 test epoch: 44950, loss: 0.00587622
train epoch: 45000, loss: 0.00376944
 test epoch: 45000, loss: 0.00587501
 test epoch: 45050, loss: 0.00587386
train epoch: 45100, loss: 0.00376873
 test epoch: 45100, loss: 0.00587266
 test epoch: 45150, loss: 0.00587145
train epoch: 45200, loss: 0.00376802
 test epoch: 45200, loss: 0.00587024
 test epoch: 45250, loss: 0.00586900
train epoch: 45300, loss: 0.00376733
 test epoch: 45300, loss: 0.00586779
 test epoch: 45350, loss: 0.00586662
train epoch: 45400, loss: 0.00376662
 test epoch: 45400, loss: 0.00586541
 test epoch: 45450, loss: 0.00586420
train epoch: 45500, loss: 0.00376592
 test epoch: 45500, loss: 0.00586300
 test epoch: 45550, loss: 0.00586179
train epoch: 45600, loss: 0.00376523
 test epoch: 45600, loss: 0.00586064
 test epoch: 45650, loss: 0.00585946
train epoch: 45700, loss: 0.00376452
 test epoch: 45700, loss: 0.00585826
 test epoch: 45750, loss: 0.00585710
train epoch: 45800, loss: 0.00376383
 test epoch: 45800, loss: 0.00585599
 test epoch: 45850, loss: 0.00585480
train epoch: 45900, loss: 0.00376314
 test epoch: 45900, loss: 0.00585360
 test epoch: 45950, loss: 0.00585237
train epoch: 46000, loss: 0.00376245
 test epoch: 46000, loss: 0.00585119
 test epoch: 46050, loss: 0.00584997
train epoch: 46100, loss: 0.00376175
 test epoch: 46100, loss: 0.00584889
 test epoch: 46150, loss: 0.00584774
train epoch: 46200, loss: 0.00376106
 test epoch: 46200, loss: 0.00584660
 test epoch: 46250, loss: 0.00584542
train epoch: 46300, loss: 0.00376038
 test epoch: 46300, loss: 0.00584420
 test epoch: 46350, loss: 0.00584304
train epoch: 46400, loss: 0.00375970
 test epoch: 46400, loss: 0.00584186
 test epoch: 46450, loss: 0.00584066
train epoch: 46500, loss: 0.00375901
 test epoch: 46500, loss: 0.00583946
 test epoch: 46550, loss: 0.00583827
train epoch: 46600, loss: 0.00375833
 test epoch: 46600, loss: 0.00583713
 test epoch: 46650, loss: 0.00583594
train epoch: 46700, loss: 0.00375764
 test epoch: 46700, loss: 0.00583478
 test epoch: 46750, loss: 0.00583364
train epoch: 46800, loss: 0.00375696
 test epoch: 46800, loss: 0.00583250
 test epoch: 46850, loss: 0.00583124
train epoch: 46900, loss: 0.00375629
 test epoch: 46900, loss: 0.00583008
 test epoch: 46950, loss: 0.00582892
train epoch: 47000, loss: 0.00375560
 test epoch: 47000, loss: 0.00582771
 test epoch: 47050, loss: 0.00582652
train epoch: 47100, loss: 0.00375494
 test epoch: 47100, loss: 0.00582538
 test epoch: 47150, loss: 0.00582424
train epoch: 47200, loss: 0.00375427
 test epoch: 47200, loss: 0.00582305
 test epoch: 47250, loss: 0.00582187
train epoch: 47300, loss: 0.00375359
 test epoch: 47300, loss: 0.00582066
 test epoch: 47350, loss: 0.00581952
train epoch: 47400, loss: 0.00375293
 test epoch: 47400, loss: 0.00581838
 test epoch: 47450, loss: 0.00581717
train epoch: 47500, loss: 0.00375225
 test epoch: 47500, loss: 0.00581600
 test epoch: 47550, loss: 0.00581479
train epoch: 47600, loss: 0.00375159
 test epoch: 47600, loss: 0.00581361
 test epoch: 47650, loss: 0.00581243
train epoch: 47700, loss: 0.00375092
 test epoch: 47700, loss: 0.00581124
 test epoch: 47750, loss: 0.00581002
train epoch: 47800, loss: 0.00375026
 test epoch: 47800, loss: 0.00580885
 test epoch: 47850, loss: 0.00580760
train epoch: 47900, loss: 0.00374959
 test epoch: 47900, loss: 0.00580642
 test epoch: 47950, loss: 0.00580522
train epoch: 48000, loss: 0.00374894
 test epoch: 48000, loss: 0.00580402
 test epoch: 48050, loss: 0.00580287
train epoch: 48100, loss: 0.00374828
 test epoch: 48100, loss: 0.00580159
 test epoch: 48150, loss: 0.00580035
train epoch: 48200, loss: 0.00374761
 test epoch: 48200, loss: 0.00579917
 test epoch: 48250, loss: 0.00579801
train epoch: 48300, loss: 0.00374696
 test epoch: 48300, loss: 0.00579678
 test epoch: 48350, loss: 0.00579561
train epoch: 48400, loss: 0.00374630
 test epoch: 48400, loss: 0.00579440
 test epoch: 48450, loss: 0.00579318
train epoch: 48500, loss: 0.00374565
 test epoch: 48500, loss: 0.00579195
 test epoch: 48550, loss: 0.00579074
train epoch: 48600, loss: 0.00374498
 test epoch: 48600, loss: 0.00578954
 test epoch: 48650, loss: 0.00578833
train epoch: 48700, loss: 0.00374434
 test epoch: 48700, loss: 0.00578713
 test epoch: 48750, loss: 0.00578591
train epoch: 48800, loss: 0.00374368
 test epoch: 48800, loss: 0.00578468
 test epoch: 48850, loss: 0.00578350
train epoch: 48900, loss: 0.00374303
 test epoch: 48900, loss: 0.00578230
 test epoch: 48950, loss: 0.00578101
train epoch: 49000, loss: 0.00374238
 test epoch: 49000, loss: 0.00577985
 test epoch: 49050, loss: 0.00577859
train epoch: 49100, loss: 0.00374173
 test epoch: 49100, loss: 0.00577734
 test epoch: 49150, loss: 0.00577611
train epoch: 49200, loss: 0.00374107
 test epoch: 49200, loss: 0.00577489
 test epoch: 49250, loss: 0.00577365
train epoch: 49300, loss: 0.00374042
 test epoch: 49300, loss: 0.00577242
 test epoch: 49350, loss: 0.00577120
train epoch: 49400, loss: 0.00373977
 test epoch: 49400, loss: 0.00577002
 test epoch: 49450, loss: 0.00576883
train epoch: 49500, loss: 0.00373911
 test epoch: 49500, loss: 0.00576760
 test epoch: 49550, loss: 0.00576642
train epoch: 49600, loss: 0.00373846
 test epoch: 49600, loss: 0.00576522
 test epoch: 49650, loss: 0.00576401
train epoch: 49700, loss: 0.00373781
 test epoch: 49700, loss: 0.00576278
 test epoch: 49750, loss: 0.00576159
train epoch: 49800, loss: 0.00373715
 test epoch: 49800, loss: 0.00576035
 test epoch: 49850, loss: 0.00575918
train epoch: 49900, loss: 0.00373650
 test epoch: 49900, loss: 0.00575796
 test epoch: 49950, loss: 0.00575677
train epoch: 50000, loss: 0.00373585
 test epoch: 50000, loss: 0.00575555
 test epoch: 50050, loss: 0.00575429
train epoch: 50100, loss: 0.00373519
 test epoch: 50100, loss: 0.00575309
 test epoch: 50150, loss: 0.00575189
train epoch: 50200, loss: 0.00373454
 test epoch: 50200, loss: 0.00575069
 test epoch: 50250, loss: 0.00574941
train epoch: 50300, loss: 0.00373389
 test epoch: 50300, loss: 0.00574823
 test epoch: 50350, loss: 0.00574699
train epoch: 50400, loss: 0.00373324
 test epoch: 50400, loss: 0.00574579
 test epoch: 50450, loss: 0.00574455
train epoch: 50500, loss: 0.00373258
 test epoch: 50500, loss: 0.00574337
 test epoch: 50550, loss: 0.00574215
train epoch: 50600, loss: 0.00373192
 test epoch: 50600, loss: 0.00574094
 test epoch: 50650, loss: 0.00573971
train epoch: 50700, loss: 0.00373127
 test epoch: 50700, loss: 0.00573854
 test epoch: 50750, loss: 0.00573729
train epoch: 50800, loss: 0.00373062
 test epoch: 50800, loss: 0.00573608
 test epoch: 50850, loss: 0.00573491
train epoch: 50900, loss: 0.00372997
 test epoch: 50900, loss: 0.00573365
 test epoch: 50950, loss: 0.00573242
train epoch: 51000, loss: 0.00372934
 test epoch: 51000, loss: 0.00573122
 test epoch: 51050, loss: 0.00573000
train epoch: 51100, loss: 0.00372868
 test epoch: 51100, loss: 0.00572874
 test epoch: 51150, loss: 0.00572751
train epoch: 51200, loss: 0.00372805
 test epoch: 51200, loss: 0.00572627
 test epoch: 51250, loss: 0.00572502
train epoch: 51300, loss: 0.00372739
 test epoch: 51300, loss: 0.00572373
 test epoch: 51350, loss: 0.00572248
train epoch: 51400, loss: 0.00372674
 test epoch: 51400, loss: 0.00572121
 test epoch: 51450, loss: 0.00571968
train epoch: 51500, loss: 0.00372609
 test epoch: 51500, loss: 0.00571828
 test epoch: 51550, loss: 0.00571700
train epoch: 51600, loss: 0.00372550
 test epoch: 51600, loss: 0.00571571
 test epoch: 51650, loss: 0.00571447
train epoch: 51700, loss: 0.00372490
 test epoch: 51700, loss: 0.00571324
 test epoch: 51750, loss: 0.00571205
train epoch: 51800, loss: 0.00372430
 test epoch: 51800, loss: 0.00571081
 test epoch: 51850, loss: 0.00570957
train epoch: 51900, loss: 0.00372371
 test epoch: 51900, loss: 0.00570837
 test epoch: 51950, loss: 0.00570716
train epoch: 52000, loss: 0.00372313
 test epoch: 52000, loss: 0.00570597
 test epoch: 52050, loss: 0.00570472
train epoch: 52100, loss: 0.00372253
 test epoch: 52100, loss: 0.00570350
 test epoch: 52150, loss: 0.00570234
train epoch: 52200, loss: 0.00372195
 test epoch: 52200, loss: 0.00570107
 test epoch: 52250, loss: 0.00569986
train epoch: 52300, loss: 0.00372136
 test epoch: 52300, loss: 0.00569865
 test epoch: 52350, loss: 0.00569744
train epoch: 52400, loss: 0.00372078
 test epoch: 52400, loss: 0.00569622
 test epoch: 52450, loss: 0.00569501
train epoch: 52500, loss: 0.00372019
 test epoch: 52500, loss: 0.00569385
 test epoch: 52550, loss: 0.00569266
train epoch: 52600, loss: 0.00371960
 test epoch: 52600, loss: 0.00569154
 test epoch: 52650, loss: 0.00569031
train epoch: 52700, loss: 0.00371901
 test epoch: 52700, loss: 0.00568910
 test epoch: 52750, loss: 0.00568791
train epoch: 52800, loss: 0.00371844
 test epoch: 52800, loss: 0.00568674
 test epoch: 52850, loss: 0.00568551
train epoch: 52900, loss: 0.00371786
 test epoch: 52900, loss: 0.00568433
 test epoch: 52950, loss: 0.00568319
train epoch: 53000, loss: 0.00371728
 test epoch: 53000, loss: 0.00568200
 test epoch: 53050, loss: 0.00568085
train epoch: 53100, loss: 0.00371670
 test epoch: 53100, loss: 0.00567963
 test epoch: 53150, loss: 0.00567847
train epoch: 53200, loss: 0.00371613
 test epoch: 53200, loss: 0.00567727
 test epoch: 53250, loss: 0.00567615
train epoch: 53300, loss: 0.00371555
 test epoch: 53300, loss: 0.00567490
 test epoch: 53350, loss: 0.00567372
train epoch: 53400, loss: 0.00371498
 test epoch: 53400, loss: 0.00567254
 test epoch: 53450, loss: 0.00567135
train epoch: 53500, loss: 0.00371441
 test epoch: 53500, loss: 0.00567015
 test epoch: 53550, loss: 0.00566900
train epoch: 53600, loss: 0.00371385
 test epoch: 53600, loss: 0.00566779
 test epoch: 53650, loss: 0.00566660
train epoch: 53700, loss: 0.00371328
 test epoch: 53700, loss: 0.00566540
 test epoch: 53750, loss: 0.00566420
train epoch: 53800, loss: 0.00371271
 test epoch: 53800, loss: 0.00566302
 test epoch: 53850, loss: 0.00566181
train epoch: 53900, loss: 0.00371217
 test epoch: 53900, loss: 0.00566060
 test epoch: 53950, loss: 0.00565945
train epoch: 54000, loss: 0.00371160
 test epoch: 54000, loss: 0.00565827
 test epoch: 54050, loss: 0.00565707
train epoch: 54100, loss: 0.00371105
 test epoch: 54100, loss: 0.00565587
 test epoch: 54150, loss: 0.00565469
train epoch: 54200, loss: 0.00371050
 test epoch: 54200, loss: 0.00565351
 test epoch: 54250, loss: 0.00565228
train epoch: 54300, loss: 0.00370995
 test epoch: 54300, loss: 0.00565108
 test epoch: 54350, loss: 0.00564986
train epoch: 54400, loss: 0.00370941
 test epoch: 54400, loss: 0.00564872
 test epoch: 54450, loss: 0.00564749
train epoch: 54500, loss: 0.00370886
 test epoch: 54500, loss: 0.00564630
 test epoch: 54550, loss: 0.00564511
train epoch: 54600, loss: 0.00370833
 test epoch: 54600, loss: 0.00564396
 test epoch: 54650, loss: 0.00564271
train epoch: 54700, loss: 0.00370780
 test epoch: 54700, loss: 0.00564152
 test epoch: 54750, loss: 0.00564033
train epoch: 54800, loss: 0.00370726
 test epoch: 54800, loss: 0.00563911
 test epoch: 54850, loss: 0.00563799
train epoch: 54900, loss: 0.00370673
 test epoch: 54900, loss: 0.00563678
 test epoch: 54950, loss: 0.00563560
train epoch: 55000, loss: 0.00370620
 test epoch: 55000, loss: 0.00563440
 test epoch: 55050, loss: 0.00563317
train epoch: 55100, loss: 0.00370568
 test epoch: 55100, loss: 0.00563203
 test epoch: 55150, loss: 0.00563079
train epoch: 55200, loss: 0.00370516
 test epoch: 55200, loss: 0.00562959
 test epoch: 55250, loss: 0.00562835
train epoch: 55300, loss: 0.00370464
 test epoch: 55300, loss: 0.00562716
 test epoch: 55350, loss: 0.00562600
train epoch: 55400, loss: 0.00370412
 test epoch: 55400, loss: 0.00562481
 test epoch: 55450, loss: 0.00562373
train epoch: 55500, loss: 0.00370362
 test epoch: 55500, loss: 0.00562251
 test epoch: 55550, loss: 0.00562135
train epoch: 55600, loss: 0.00370311
 test epoch: 55600, loss: 0.00562015
 test epoch: 55650, loss: 0.00561899
train epoch: 55700, loss: 0.00370262
 test epoch: 55700, loss: 0.00561784
 test epoch: 55750, loss: 0.00561659
train epoch: 55800, loss: 0.00370211
 test epoch: 55800, loss: 0.00561542
 test epoch: 55850, loss: 0.00561429
train epoch: 55900, loss: 0.00370162
 test epoch: 55900, loss: 0.00561316
 test epoch: 55950, loss: 0.00561195
train epoch: 56000, loss: 0.00370113
 test epoch: 56000, loss: 0.00561079
 test epoch: 56050, loss: 0.00560963
train epoch: 56100, loss: 0.00370064
 test epoch: 56100, loss: 0.00560838
 test epoch: 56150, loss: 0.00560728
train epoch: 56200, loss: 0.00370015
 test epoch: 56200, loss: 0.00560618
 test epoch: 56250, loss: 0.00560503
train epoch: 56300, loss: 0.00369968
 test epoch: 56300, loss: 0.00560394
 test epoch: 56350, loss: 0.00560281
train epoch: 56400, loss: 0.00369920
 test epoch: 56400, loss: 0.00560169
 test epoch: 56450, loss: 0.00560049
train epoch: 56500, loss: 0.00369872
 test epoch: 56500, loss: 0.00559938
 test epoch: 56550, loss: 0.00559827
train epoch: 56600, loss: 0.00369826
 test epoch: 56600, loss: 0.00559717
 test epoch: 56650, loss: 0.00559607
train epoch: 56700, loss: 0.00369779
 test epoch: 56700, loss: 0.00559493
 test epoch: 56750, loss: 0.00559379
train epoch: 56800, loss: 0.00369733
 test epoch: 56800, loss: 0.00559272
 test epoch: 56850, loss: 0.00559166
train epoch: 56900, loss: 0.00369688
 test epoch: 56900, loss: 0.00559052
 test epoch: 56950, loss: 0.00558937
train epoch: 57000, loss: 0.00369643
 test epoch: 57000, loss: 0.00558829
 test epoch: 57050, loss: 0.00558719
train epoch: 57100, loss: 0.00369599
 test epoch: 57100, loss: 0.00558607
 test epoch: 57150, loss: 0.00558502
train epoch: 57200, loss: 0.00369554
 test epoch: 57200, loss: 0.00558392
 test epoch: 57250, loss: 0.00558288
train epoch: 57300, loss: 0.00369511
 test epoch: 57300, loss: 0.00558182
 test epoch: 57350, loss: 0.00558075
train epoch: 57400, loss: 0.00369468
 test epoch: 57400, loss: 0.00557967
... (training log elided: train loss reported every 100 epochs, test loss every 50; both decrease slowly and monotonically over this stretch) ...
train epoch: 60600, loss: 0.00368349
 test epoch: 60600, loss: 0.00552509
---------------------------------------------------------------------------
KeyboardInterrupt                         Traceback (most recent call last)
Cell In [53], line 9
      7 rng, input_rng = jax.random.split(rng)
      8 # Run an optimization step over a training batch
----> 9 state = train_epoch(state, q_train_dataset, x_train_dataset, batch_size, epoch, input_rng)
     10 # Evaluate on the test set after each training epoch
     11 test_loss= eval_model(state.params, q_test_dataset, x_test_dataset)

Cell In [20], line 354, in train_epoch(state, train_ds, xtrain_ds, batch_size, epoch, rng)
    352 batch_metrics = []
    353 for batch, xbatch in zip(train_ds_batched, xtrain_ds_batched):
--> 354   state, metrics = train_step(state, batch, xbatch)
    355   batch_metrics.append(metrics)
    357 # compute mean of metrics across each batch in epoch.

File ~/py_dspy/lib/python3.9/site-packages/flax/core/frozen_dict.py:159, in FrozenDict.tree_unflatten(cls, _, data)
    152   """Flattens this FrozenDict.
    153 
    154   Returns:
    155     A flattened version of this FrozenDict instance.
    156   """
    157   return (self._dict,), ()
--> 159 @classmethod
    160 def tree_unflatten(cls, _, data):
    161   # data is already deep copied due to tree map mechanism
    162   # we can skip the deep copy in the constructor
    163   return cls(*data, __unsafe_skip_copy__=True)

KeyboardInterrupt: 
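The interrupt above was manual, and it discards any progress since the last completed epoch. A sketch of how the training loop could checkpoint itself periodically, assuming the flax.training.checkpoints module shipped with this Flax release; the directory name, the 500-epoch cadence, and num_epochs are illustrative, not values from this notebook:

    from flax.training import checkpoints

    CKPT_DIR = './ckpts_dof6_rom'  # hypothetical output directory

    for epoch in range(1, num_epochs + 1):
        rng, input_rng = jax.random.split(rng)
        state = train_epoch(state, q_train_dataset, x_train_dataset,
                            batch_size, epoch, input_rng)
        if epoch % 500 == 0:
            # keep only the three most recent checkpoints on disk
            checkpoints.save_checkpoint(CKPT_DIR, target=state,
                                        step=epoch, keep=3, overwrite=True)

    # resume later with:
    # state = checkpoints.restore_checkpoint(CKPT_DIR, target=state)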
In [56]:
Nsim = 35000

# Encode the normalized training trajectories into the latent space and
# decode back to the full state for comparison.
Qpred = encode_predict(state.params, QnormTrainData[:Nsim,:,:])
Xpred = decode_predict(state.params, Qpred)
print(Qpred.shape)

sample = np.array([0,10,30])

plt.figure(figsize = (16,12))

# Left column: latent channels (displacement, velocity, acceleration).
# Right column: decoded full-order channels at a few sampled DOFs.
plt.subplot(3,2,1)
plt.plot(Qpred[:,:,0])
plt.plot(QnormTrainData[:Nsim,:,0], 'k--')

plt.subplot(3,2,2)
plt.plot(Xpred[:,sample,0])
plt.plot(XnormTrainData[:Nsim,sample,0], 'k--')

plt.subplot(3,2,3)
plt.plot(Qpred[:,:,1])

plt.subplot(3,2,4)
plt.plot(Xpred[:,sample,1])
plt.plot(XnormTrainData[:Nsim,sample,1], 'k--')

plt.subplot(3,2,5)
plt.plot(Qpred[:,:,2])

plt.subplot(3,2,6)
plt.plot(Xpred[:,sample,2])
plt.plot(XnormTrainData[:Nsim,sample,2], 'k--')

plt.show()
(16384, 6, 4)
In [57]:
loss_fn_(state.params, QnormTrainData[:Nsim,:,:], XnormTrainData[:Nsim,:,:])
(16384, 6) (16384, 64)
(16384, 6) (16384, 6)
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
... (this debug shape print repeats for each internal matrix evaluation; the fourth shape alternates between (1, 6) and (6,)) ...
Out[57]:
DeviceArray([ 8.6628074e-07,  4.5068827e-12,  8.6628074e-07,
              2.8949065e-12,  8.6628077e-10,  1.5462806e-02,
              1.6548460e-03,  6.4100551e-11, -2.4795823e-05],            dtype=float32)
In [58]:
# Learned kinetic, potential, and total (Hamiltonian) energies along the
# training trajectory, split into linear and neural-network contributions.
That,That_lin,That_nl = predict_kinetic(state.params, QnormTrainData[:Nsim,:,:])
Vhat,Vhat_lin,Vhat_nl = predict_potential(state.params, QnormTrainData[:Nsim,:,:])
Hhat = predict_Hamiltonian(state.params, QnormTrainData[:Nsim,:,:])

plt.figure(figsize = (16,12))
plt.subplot(3,1,1)
plt.plot(That_lin)
plt.plot(That_nl)

plt.title('kinetic energy')
plt.legend(['linear', 'nnet'])

plt.subplot(3,1,2)
plt.plot(Vhat_lin)
plt.plot(Vhat_nl)

plt.title('potential energy')
plt.legend(['linear', 'nnet'])

plt.subplot(3,1,3)
plt.plot(Vhat_lin + That_lin)
plt.plot(Hhat[:,0,0], 'k--')

plt.title('total mechanical energy')


plt.show()
(6, 6) (6, 6) (6, 6) (6,) (6, 6)
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
In [60]:
plt.figure(figsize = (12,4))
plt.subplot(1,2,1)
plt.plot(That_lin - Vhat_lin)
plt.plot(That - Vhat, 'k--', alpha = 0.8)
ax = plt.gca()
ax2 = ax.twinx()
# right axis: discrepancy between the nnet and linear Lagrangians
ax2.plot((That - Vhat).reshape((-1,)) - (That_lin - Vhat_lin).reshape((-1,)), 'm')
ax2.set_ylim([0, 0.5])

plt.title('Lagrangian (T - V)')
ax.legend(['linear', 'nnet'])
ax.set_ylabel('$L$')
ax2.set_ylabel('$L_{nn} - L_{lin}$')

plt.subplot(1,2,2)
plt.plot(Vhat_lin + That_lin)
plt.plot(Hhat[:,0,0], 'k--')
plt.legend(['linear', 'nnet'])
plt.ylabel('$H$')
plt.ylim([10,13])

plt.title('total mechanical energy (T + V)')

plt.tight_layout()
plt.show()
In [30]:
Mhat_pred,Khat_pred,Rhat_pred,Mhat_nl,Khat_nl,Chat_nl = predict_MKC(state.params, QnormTrainData[:Nsim,:,:])
(6, 6) (6, 6) (6, 6) (6,) (6, 6)
In [61]:
plt.figure(figsize = (10,10))

labels = [r'$\mathbf{M}$', r'$\delta \mathbf{M}(q)$', r'$\mathbf{K}$', r'$\delta \mathbf{K}(q)$', r'$\mathbf{C}$', r'$\delta \mathbf{C}(\dot{q})$']

plt.subplot(3,2,1)
plt.imshow(MHAT)
plt.colorbar()
plt.title(labels[0])
plt.axis('off')

plt.subplot(3,2,2)
plt.imshow(Mhat_nl[10])
plt.colorbar()
plt.title(labels[1])
plt.axis('off')

plt.subplot(3,2,3)
plt.imshow(KHAT)
plt.colorbar()
plt.title(labels[2])
plt.axis('off')

plt.subplot(3,2,4)
plt.imshow(Khat_nl[100])
plt.colorbar()
plt.title(labels[3])
plt.axis('off')

plt.subplot(3,2,5)
plt.imshow(CHAT)
plt.colorbar()
plt.title(labels[4])
plt.axis('off')

plt.subplot(3,2,6)
plt.imshow(Chat_nl[10])
plt.colorbar()
plt.title(labels[5])
plt.axis('off')

plt.tight_layout()
plt.show()
In [62]:
# Consistency check: velocity/acceleration channels vs. finite differences of
# the channel one derivative below, for both the data and the predictions.
plt.figure(figsize = (10,10))

plt.subplot(2,2,1)
plt.plot(XnormTrainData[:Nsim,:,1])
plt.plot(np.diff(XnormTrainData[:Nsim,:,0], axis = 0)/dt, 'k--')

plt.subplot(2,2,2)
plt.plot(XnormTrainData[:Nsim,:,2])
plt.plot(np.diff(XnormTrainData[:Nsim,:,1], axis = 0)/dt, 'k--')

plt.subplot(2,2,3)
plt.plot(Xpred[:,:,1])
plt.plot(np.diff(Xpred[:,:,0], axis = 0)/dt, 'k--')

plt.subplot(2,2,4)
plt.plot(Xpred[:,:,2])
plt.plot(np.diff(Xpred[:,:,1], axis = 0)/dt, 'k--')

plt.show()
In [63]:
# Acceleration from the learned equations of motion, compared against the
# latent acceleration channel.
ddq_hat = predict_ddq(state.params, Qpred[:])

plt.figure()
plt.plot(ddq_hat)
plt.plot(Qpred[:,:,2], 'k--')

plt.show()
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
(6, 6) (6, 6) (6, 6) (6,) (6, 6)
(6, 6) (6, 6) (6, 6) (6,) (6, 6)
In [64]:
from scipy.interpolate import interp1d

# External forcing channel of the latent data; interpolate it in time so the
# ODE integrator below can query it at arbitrary t.
Qext_pred = QnormTrainData[:,:,3]
print(Qext_pred.shape)

time_vec_int = t[:].reshape((-1,))
fbar_int = interp1d(time_vec_int[:Qext_pred.shape[0]], Qext_pred, axis = 0)
print(time_vec_int.shape, Qext_pred.shape)
(16384, 6)
(16384,) (16384, 6)
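A quick sanity check (sketch; the sample index k is arbitrary): the interpolant reproduces the forcing exactly at the grid points and fills in linearly between them, which matters because the ODE solver below evaluates the right-hand side at times of its own choosing:

    k = 123
    assert np.allclose(fbar_int(time_vec_int[k]), Qext_pred[k])
    # midpoint query returns one forcing vector of shape (6,)
    print(fbar_int(0.5*(time_vec_int[k] + time_vec_int[k+1])).shape)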
In [65]:
# Initial condition for the integrator: latent displacements stacked on top of
# latent velocities, matching the [q; dq] state layout used by learned_q below.
qICs = np.concatenate([Qpred[0,:,0], Qpred[0,:,1]])
print(qICs)
[ 0.          0.          0.          0.          0.          0.
  3.5880363  -1.92512     2.4119582  -0.00745718 -0.01118277  0.02246637]
In [71]:
import time
import scipy.integrate

dof = LATENT

def tf_dydx(q, dq, f, params):
    return f_diff(params, q, dq, f)


def learned_q(x, t, params, fbar):
    """Right-hand side for odeint: x = [q; dq]  ->  [dq; ddq]."""
    x = np.array(x)
    t = np.array([t])

    q = x[:dof]
    dq = x[dof:]

    # interpolated external forcing at the solver's current time
    f = fbar(t).reshape((-1,))

    q = q.reshape((-1,LATENT))
    dq = dq.reshape((-1,LATENT))
    f = f.reshape((-1,LATENT))

    dydx = tf_dydx(q, dq, f, params)

    return dydx.flatten()


Nint = Ntt*test_factor-10

start_ = time.time()
qout = scipy.integrate.odeint(learned_q, qICs, time_vec_int[:Nint], (state.params,fbar_int))
end = time.time()
print(end - start_)
1.627075433731079
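For reference, the same trajectory can be produced with scipy.integrate.solve_ivp, whose callback signature is fun(t, y) rather than odeint's fun(y, t). A minimal sketch; the tolerances are illustrative:

    from scipy.integrate import solve_ivp

    sol = solve_ivp(lambda tt, y: learned_q(y, tt, state.params, fbar_int),
                    t_span=(time_vec_int[0], time_vec_int[Nint-1]),
                    y0=qICs, t_eval=time_vec_int[:Nint],
                    rtol=1e-6, atol=1e-9)
    qout_ivp = sol.y.T  # (Nint, 2*dof), same layout as qout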
In [72]:
plt.figure(figsize = (10,10))

print(dof)

# Split the integrated trajectory into displacement and velocity halves, and
# recover the latent acceleration from the learned dynamics.
xpred = qout[:, :dof]
vpred = qout[:, dof:]
apred = f_diff(state.params, xpred, vpred, Qext_pred[:Nint,:])[:,dof:]

plt.subplot(3, 1, 1)

plt.plot(time_vec_int[:Nint], Qpred[:Nint,:,0], linewidth = 4, alpha = 0.6)
plt.plot(time_vec_int[:Nint], xpred, 'k')
plt.ylabel('Displacement')
plt.xlabel('Time, s')

plt.subplot(3, 1, 2)

plt.plot(time_vec_int[:Nint], Qpred[:Nint,:,1], linewidth = 4, alpha = 0.6)
plt.plot(time_vec_int[:Nint], vpred, 'k')
plt.ylabel('Velocity')
plt.xlabel('Time, s')
#plt.xlim([0, 0.1])

plt.subplot(3, 1, 3)

plt.plot(time_vec_int[:Nint], Qpred[:Nint,:,2], linewidth = 4, alpha = 0.6)
plt.plot(time_vec_int[:Nint], apred, 'k')
plt.ylabel('Acceleration')
plt.xlabel('Time, s')
#plt.xlim([0, 0.1])

plt.tight_layout()
#plt.savefig('latent_displacements_time_NL.tif')
#plt.savefig('/content/drive/MyDrive/UCSD_research/jax/chain_of_masses/time_integrated_latent_kprime2p0.png')
plt.show()
6
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
(6, 6) (6, 6) (6, 6) (1, 6) (6, 6)
(6, 6) (6, 6) (6, 6) (6,) (6, 6)
(6, 6) (6, 6) (6, 6) (6,) (6, 6)
In [73]:
Qpred_int = np.concatenate([xpred[:,:,np.newaxis],
                            vpred[:,:,np.newaxis],
                            apred[:,:,np.newaxis],
                            Qext_pred[:Nint,:,np.newaxis]], axis = -1)
Xpred_int = decode_predict(state.params, Qpred_int)
print(Xpred_int.shape)
(16374, 64, 3)
In [74]:
# Purely linear reconstruction V q of the integrated latent trajectory, to
# compare against the nonlinear decoder output Xpred_int.
Xpred_Vq_q = np.dot(VMAT, Qpred_int[:,:,0].T).T
Xpred_Vq_dq = np.dot(VMAT, Qpred_int[:,:,1].T).T

Xpred_Vq = np.concatenate([Xpred_Vq_q[:,:,np.newaxis], Xpred_Vq_dq[:,:,np.newaxis]], axis = 2)

print(Xpred_Vq.shape)
(16374, 64, 2)
In [75]:
np.random.seed(3)
plt.figure(figsize = (10,8))

ix = np.random.randint(0,64,3)
#start = 20
Nplot = Nint

plt.subplot(3, 1, 1)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,0], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], Xpred_int[:Nplot,ix,0], 'k--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Displacement')
plt.xlabel('Time, s')
plt.title('NNet')


plt.subplot(3, 1, 2)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,0], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], Qrr.T[start:Nplot+start,ix], 'k--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Displacement')
plt.xlabel('Time, s')
plt.title('ROM')


plt.tight_layout()
plt.show()
In [78]:
np.random.seed(3)
plt.figure(figsize = (10,8))

ix = np.random.randint(0,64,3)
#start = 20
Nplot = Nint

plt.subplot(3, 1, 1)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,0], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], Xpred_int[:Nplot,ix,0], 'k--')
plt.plot(time_vec_int[:Nplot], Xpred_Vq[:Nplot,ix,0], 'g--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Displacement')
plt.xlabel('Time, s')
plt.title('NNet')


plt.subplot(3, 1, 2)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,0], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], Qrr.T[start:Nplot+start,ix], 'k--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Displacement')
plt.xlabel('Time, s')
plt.title('ROM')


plt.tight_layout()
plt.show()
In [79]:
# RMSE (not MSE) of the displacement reconstructions
errors_disp_nn = np.sqrt(np.mean((XnormTrainData[:Nplot,:,0] - Xpred_int[:Nplot,:,0])**2))
errors_disp_rom = np.sqrt(np.mean((XnormTrainData[:Nplot,:,0] - Qrr.T[start:Nplot+start,:])**2))
errors_disp_nnVq = np.sqrt(np.mean((XnormTrainData[:Nplot,:,0] - Xpred_Vq[:Nplot,:,0])**2))

print(errors_disp_nn, errors_disp_rom, errors_disp_nnVq)
0.0019257539 0.0068629194263998185 0.002322670679603171
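These are absolute RMSEs on the normalized displacements; dividing by the RMS of the truth gives a scale-free comparison. A small helper (sketch; the function name is ours, not from the notebook):

    def rel_rmse(truth, pred):
        # RMSE normalized by the RMS of the reference signal
        return np.sqrt(np.mean((truth - pred)**2)) / np.sqrt(np.mean(truth**2))

    print(rel_rmse(XnormTrainData[:Nplot,:,0], Xpred_int[:Nplot,:,0]),
          rel_rmse(XnormTrainData[:Nplot,:,0], Qrr.T[start:Nplot+start,:]),
          rel_rmse(XnormTrainData[:Nplot,:,0], Xpred_Vq[:Nplot,:,0]))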
In [80]:
# Finite-difference velocity for the ROM trajectory; prepend the true initial
# velocity so dQrr stays aligned sample-for-sample with Qrr.
dQrr = np.diff(Qrr.T, axis = 0)/dt
dQrr = np.insert(dQrr, 0, XnormTrainData[0,:,1], axis = 0)
dQrr.shape
Out[80]:
(16386, 64)
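An alternative sketch: np.gradient returns a same-length, second-order finite-difference estimate, avoiding the manual np.insert alignment step (though it uses one-sided differences at the ends rather than the true initial velocity):

    dQrr_grad = np.gradient(Qrr.T, dt.item(), axis=0)
    print(dQrr_grad.shape)  # same length as Qrr.T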
In [81]:
np.random.seed(3)
plt.figure(figsize = (10,8))

ix = np.random.randint(0,64,3)
#start = 20
Nplot = Nint

plt.subplot(3, 1, 1)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,1], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], Xpred_int[:Nplot,ix,1], 'k--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Velocity')
plt.xlabel('Time, s')
plt.title('NNet')


plt.subplot(3, 1, 2)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,1], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], dQrr[start:Nplot+start,ix], 'k--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Velocity')
plt.xlabel('Time, s')
plt.title('ROM')


plt.tight_layout()
plt.show()
In [82]:
np.random.seed(3)
plt.figure(figsize = (10,8))

ix = np.random.randint(0,64,3)
#start = 20
Nplot = Nint

plt.subplot(3, 1, 1)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,1], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], Xpred_int[:Nplot,ix,1], 'k--')
plt.plot(time_vec_int[:Nplot], Xpred_Vq[:Nplot,ix,1], 'g--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Velocity')
plt.xlabel('Time, s')
plt.title('NNet')


plt.subplot(3, 1, 2)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,1], linewidth = 2, alpha = 0.6)
plt.plot(time_vec_int[:Nplot], dQrr[start:Nplot+start,ix], 'k--')
#plt.vlines(time_vec_int[8000], ymin = -0.2, ymax = 0.2, color = 'r', linestyle = '--')
plt.ylabel('Velocity')
plt.xlabel('Time, s')
plt.title('ROM')


plt.tight_layout()
plt.show()
In [83]:
errors_vel_nn = np.sqrt(np.mean((XnormTrainData[:Nplot,:,1] - Xpred_int[:Nplot,:,1])**2))
errors_vel_rom = np.sqrt(np.mean((XnormTrainData[:Nplot,:,1] - dQrr[start:Nplot+start,:])**2))
errors_vel_nnVq = np.sqrt(np.mean((XnormTrainData[:Nplot,:,1] - Xpred_Vq[:Nplot,:,1])**2))

print(errors_vel_nn, errors_vel_rom, errors_vel_nnVq)
0.036848325 0.12551573807078967 0.0535833302806473
In [84]:
np.random.seed(3)
plt.figure(figsize = (10,8))

ix = 2   # single DOF to plot
start = 0
Nplot = Nint


plt.subplot(2, 2, 1)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,0], linewidth = 1, alpha = 0.7)
plt.plot(time_vec_int[:Nplot], Xpred_int[:Nplot,ix,0], 'r', alpha = 0.5)
plt.plot(time_vec_int[:Nplot], Xpred_Vq[:Nplot,ix,0], 'g--', alpha = 0.5)
plt.vlines(time_vec_int[Ntt], ymin = -0.06, ymax = 0.06, color = 'k', linestyle = '--')
plt.legend(['Truth', 'Pred rmse (nl): {0:.4f}'.format(errors_disp_nn), 'Pred rmse (lin): {0:.4f}'.format(errors_disp_nnVq)])
plt.ylabel('Displacement')
plt.xlabel('Time, s')
plt.title('NNet')


plt.subplot(2, 2, 2)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,0], linewidth = 1, alpha = 0.7)
plt.plot(time_vec_int[:Nplot],  Qrr.T[start:Nplot+start,ix], 'r', alpha = 0.6)
plt.legend(['Truth', 'Pred rmse: {0:.4f}'.format(errors_disp_rom)])
plt.vlines(time_vec_int[Ntt], ymin = -0.06, ymax = 0.06, color = 'k', linestyle = '--')
plt.ylabel('Displacement')
plt.xlabel('Time, s')
plt.title('ROM')



plt.subplot(2, 2, 3)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,1], linewidth = 1, alpha = 0.7)
plt.plot(time_vec_int[:Nplot], Xpred_int[:Nplot,ix,1], 'r', alpha = 0.5)
plt.plot(time_vec_int[:Nplot], Xpred_Vq[:Nplot,ix,1], 'g--', alpha = 0.5)
plt.legend(['Truth', 'Pred rmse (nl): {0:.4f}'.format(errors_vel_nn), 'Pred rmse (lin): {0:.4f}'.format(errors_vel_nnVq)])
plt.vlines(time_vec_int[Ntt], ymin = -0.6, ymax = 0.6, color = 'k', linestyle = '--')
plt.ylabel('Velocity')
plt.xlabel('Time, s')
plt.title('NNet')


plt.subplot(2, 2, 4)

plt.plot(time_vec_int[:Nplot], XnormTrainData[:Nplot,ix,1], linewidth = 1, alpha = 0.7)
plt.plot(time_vec_int[:Nplot], dQrr[start:Nplot+start,ix], 'r', alpha = 0.6)
plt.legend(['Truth', 'Pred rmse: {0:.4f}'.format(errors_vel_rom)])
plt.vlines(time_vec_int[Ntt], ymin = -0.6, ymax = 0.6, color = 'k', linestyle = '--')
plt.ylabel('Velocity')
plt.xlabel('Time, s')
plt.title('ROM')


plt.tight_layout()
plt.show()
In [85]:
np.random.seed(3)
plt.figure(figsize = (10,4))

ix = np.random.randint(0,64,1)
#start = 20
Nplot = Nint

plt.subplot(1, 2, 1)
plt.title('NNet')
plt.plot(XnormTrainData[:Nplot,ix,0], XnormTrainData[:Nplot,ix,1])
plt.plot(Xpred_int[:Nplot,ix,0], Xpred_int[:Nplot,ix,1], 'k')
plt.xlabel('$x$')
plt.ylabel(r'$\dot{x}$')
plt.legend(['Truth', 'Prediction'])

plt.subplot(1, 2, 2)
plt.title('ROM')
plt.plot(XnormTrainData[:Nplot,ix,0], XnormTrainData[:Nplot,ix,1])
plt.plot(Qrr.T[start:Nplot+start,ix], dQrr[start:Nplot+start,ix], 'k')
plt.xlabel('$x$')
plt.ylabel(r'$\dot{x}$')
plt.legend(['Truth', 'Prediction'])

plt.tight_layout()
plt.show()
In [86]:
np.random.seed(3)
plt.figure(figsize = (10,4))

ix = np.random.randint(0,64,1)
#start = 20
Nplot = Nint

plt.subplot(1, 2, 1)
plt.title('NNet')
plt.plot(XnormTrainData[:Nplot,ix,0], XnormTrainData[:Nplot,ix,1])
plt.plot(Xpred_int[:Nplot,ix,0], Xpred_int[:Nplot,ix,1], 'k')
plt.plot(Xpred_Vq[:Nplot,ix,0], Xpred_Vq[:Nplot,ix,1], 'g--')
plt.xlabel('$x$')
plt.ylabel(r'$\dot{x}$')
#plt.legend(['Truth', 'Prediction', 'Linear Project NNet'])

plt.subplot(1, 2, 2)
plt.title('ROM')
plt.plot(XnormTrainData[:Nplot,ix,0], XnormTrainData[:Nplot,ix,1])
plt.plot(Qrr.T[start:Nplot+start,ix], dQrr[start:Nplot+start,ix], 'k')
plt.xlabel('$x$')
plt.ylabel(r'$\dot{x}$')
plt.legend(['Truth', 'Prediction'])

plt.tight_layout()
plt.show()
In [87]:
fig = plt.figure(figsize = (16,14))

np.random.seed(190)
ix = np.random.randint(0,64,3)

Nplot = Ntt*2-10

ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot3D(XnormTrainData[:Nplot,ix[0],0], XnormTrainData[:Nplot,ix[1],0], XnormTrainData[:Nplot,ix[0],1], linewidth = 2)
ax.plot3D(Xpred_int[:Nplot,ix[0],0], Xpred_int[:Nplot,ix[1],0], Xpred_int[:Nplot,ix[0],1], 'k--', linewidth = 1)
#ax.plot3D(delta_x_hat[tmin:Nint], delta_xdot_hat[tmin:Nint], delta_Q_hat[tmin:Nint], 'k--')
plt.xlabel('$x_{{{}}}$'.format(ix[0]), fontsize = 28)
plt.ylabel('$x_{{{}}}$'.format(ix[1]), fontsize = 28)
ax.set_zlabel(r'$\dot{{x}}_{{{}}}$'.format(ix[0]), fontsize = 28)

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot3D(XnormTrainData[:Nplot,ix[0],0], XnormTrainData[:Nplot,ix[1],0], XnormTrainData[:Nplot,ix[0],1], linewidth = 2)
ax.plot3D(Qrr.T[:Nplot,ix[0]], Qrr.T[:Nplot,ix[1]], dQrr[:Nplot,ix[0]], 'k--', linewidth = 1)
plt.xlabel('$x_{{{}}}$'.format(ix[0]), fontsize = 28)
plt.ylabel('$x_{{{}}}$'.format(ix[1]), fontsize = 28)
ax.set_zlabel(r'$\dot{{x}}_{{{}}}$'.format(ix[0]), fontsize = 28)

plt.tight_layout()
plt.show()
In [88]:
fig = plt.figure(figsize = (16,14))

np.random.seed(190)
ix = np.random.randint(0,64,3)

Nplot = Ntt*2-10

ax = fig.add_subplot(1, 2, 1, projection='3d')
ax.plot3D(XnormTrainData[:Nplot,ix[0],0], XnormTrainData[:Nplot,ix[1],0], XnormTrainData[:Nplot,ix[0],1], linewidth = 2)
ax.plot3D(Xpred_int[:Nplot,ix[0],0], Xpred_int[:Nplot,ix[1],0], Xpred_int[:Nplot,ix[0],1], 'k--', linewidth = 1)
ax.plot3D(Xpred_Vq[:Nplot,ix[0],0], Xpred_Vq[:Nplot,ix[1],0], Xpred_Vq[:Nplot,ix[0],1], 'g--', linewidth = 1)
#ax.plot3D(delta_x_hat[tmin:Nint], delta_xdot_hat[tmin:Nint], delta_Q_hat[tmin:Nint], 'k--')
plt.xlabel('$x_{{{}}}$'.format(ix[0]), fontsize = 28)
plt.ylabel('$x_{{{}}}$'.format(ix[1]), fontsize = 28)
ax.set_zlabel(r'$\dot{{x}}_{{{}}}$'.format(ix[0]), fontsize = 28)

ax = fig.add_subplot(1, 2, 2, projection='3d')
ax.plot3D(XnormTrainData[:Nplot,ix[0],0], XnormTrainData[:Nplot,ix[1],0], XnormTrainData[:Nplot,ix[0],1], linewidth = 2)
ax.plot3D(Qrr.T[:Nplot,ix[0]], Qrr.T[:Nplot,ix[1]], dQrr[:Nplot,ix[0]], 'k--', linewidth = 1)
plt.xlabel('$x_{{{}}}$'.format(ix[0]), fontsize = 28)
plt.ylabel('$x_{{{}}}$'.format(ix[1]), fontsize = 28)
ax.set_zlabel(r'$\dot{{x}}_{{{}}}$'.format(ix[0]), fontsize = 28)

plt.tight_layout()
plt.show()